Explication
Pour les réseaux de neurones, nous utilisons habituellement loss
évaluer la façon dont le réseau a appris à classer l'image d'entrée (ou d'autres tâches). L' loss
terme est généralement une valeur scalaire. Afin de mettre à jour les paramètres du réseau, nous avons besoin de calculer le gradient de l' loss
w.r.t pour les paramètres, qui est en fait leaf node
dans le graphe (par la façon dont ces paramètres sont essentiellement les poids et les biais de différentes couches de telles Convolution Linéaire, et ainsi de suite).
Selon la chaîne de la règle, afin de calculer le gradient de l' loss
w.r.t à un nœud feuille, nous pouvons calculer la dérivé de la loss
w.r.t certains variable intermédiaire, et le gradient de la variable intermédiaire w.r.t à la feuille de variable, de faire un point produit et la somme de toutes ces.
L' gradient
arguments d'un Variable
s' backward()
méthode est utilisée pour calculer une somme pondérée de chaque élément d'une Variable w.r.t la feuille de Variable. Ces poids est simplement la dérivée de finale loss
w.r.t chaque élément de la variable intermédiaire.
Un exemple concret
Prenons un béton et exemple simple pour comprendre le présent.
from torch.autograd import Variable
import torch
x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True)
z = 2*x
loss = z.sum(dim=1)
# do backward for first element of z
z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_() #remove gradient in x.grad, or it will be accumulated
# do backward for second element of z
z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# do backward for all elements of z, with weight equal to the derivative of
# loss w.r.t z_1, z_2, z_3 and z_4
z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# or we can directly backprop using loss
loss.backward() # equivalent to loss.backward(torch.FloatTensor([1.0]))
print(x.grad.data)
Dans l'exemple ci-dessus, le résultat de la première print
est
2 0 0 0
[flambeau.FloatTensor de taille 1x4]
ce qui est exactement la dérivée de z_1 w.r.t à x.
Le résultat de la deuxième print
est :
0 2 0 0
[flambeau.FloatTensor de taille 1x4]
qui est la dérivée de z_2 w.r.t à x.
Maintenant, si l'utilisation d'un poids de [1, 1, 1, 1] pour calculer la dérivée de z w.r.t de x, le résultat est 1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx
. Donc pas de surprise, la sortie de la 3e print
est:
2 2 2 2
[flambeau.FloatTensor de taille 1x4]
Il convient de noter que le poids du vecteur [1, 1, 1, 1] est exactement dérivé de la loss
w.r.t à z_1, z_2, z_3 et z_4. Les dérivés de l' loss
w.r.t x
est calculé comme suit:
d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx
La sortie de la 4e print
est le même que le 3e print
:
2 2 2 2
[flambeau.FloatTensor de taille 1x4]