139 votes

Pytorch, quels sont les arguments du dégradé

Je lis la documentation de PyTorch et trouve un exemple où ils écrivent

 gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
 

où x était une variable initiale à partir de laquelle y a été construit (un vecteur 3). La question est de savoir quels sont les arguments 0.1, 1.0 et 0.0001 du tenseur des gradients. La documentation n'est pas très claire à ce sujet.

120voto

jdhao Points 3136

Explication

Pour les réseaux de neurones, nous utilisons habituellement loss évaluer la façon dont le réseau a appris à classer l'image d'entrée (ou d'autres tâches). L' loss terme est généralement une valeur scalaire. Afin de mettre à jour les paramètres du réseau, nous avons besoin de calculer le gradient de l' loss w.r.t pour les paramètres, qui est en fait leaf node dans le graphe (par la façon dont ces paramètres sont essentiellement les poids et les biais de différentes couches de telles Convolution Linéaire, et ainsi de suite).

Selon la chaîne de la règle, afin de calculer le gradient de l' loss w.r.t à un nœud feuille, nous pouvons calculer la dérivé de la loss w.r.t certains variable intermédiaire, et le gradient de la variable intermédiaire w.r.t à la feuille de variable, de faire un point produit et la somme de toutes ces.

L' gradient arguments d'un Variables' backward() méthode est utilisée pour calculer une somme pondérée de chaque élément d'une Variable w.r.t la feuille de Variable. Ces poids est simplement la dérivée de finale loss w.r.t chaque élément de la variable intermédiaire.

Un exemple concret

Prenons un béton et exemple simple pour comprendre le présent.

from torch.autograd import Variable
import torch
x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True)
z = 2*x
loss = z.sum(dim=1)

# do backward for first element of z
z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_() #remove gradient in x.grad, or it will be accumulated

# do backward for second element of z
z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()

# do backward for all elements of z, with weight equal to the derivative of
# loss w.r.t z_1, z_2, z_3 and z_4
z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()

# or we can directly backprop using loss
loss.backward() # equivalent to loss.backward(torch.FloatTensor([1.0]))
print(x.grad.data)    

Dans l'exemple ci-dessus, le résultat de la première print est

2 0 0 0
[flambeau.FloatTensor de taille 1x4]

ce qui est exactement la dérivée de z_1 w.r.t à x.

Le résultat de la deuxième print est :

0 2 0 0
[flambeau.FloatTensor de taille 1x4]

qui est la dérivée de z_2 w.r.t à x.

Maintenant, si l'utilisation d'un poids de [1, 1, 1, 1] pour calculer la dérivée de z w.r.t de x, le résultat est 1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx. Donc pas de surprise, la sortie de la 3e print est:

2 2 2 2
[flambeau.FloatTensor de taille 1x4]

Il convient de noter que le poids du vecteur [1, 1, 1, 1] est exactement dérivé de la loss w.r.t à z_1, z_2, z_3 et z_4. Les dérivés de l' loss w.r.t x est calculé comme suit:

d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx

La sortie de la 4e print est le même que le 3e print:

2 2 2 2
[flambeau.FloatTensor de taille 1x4]

56voto

Gu Wang Points 632

En règle générale, votre calcul graphique a un scalaire sortie, dit - loss. Ensuite, vous pouvez calculer le gradient de l' loss w.r.t. le poids (w) en loss.backward(). D'où l'argument par défaut de backward() est 1.0.

Si votre sortie a plusieurs valeurs (par exemple, loss=[loss1, loss2, loss3]), vous pouvez calculer le gradient de la perte de w.r.t. le poids en loss.backward(torch.FloatTensor([1.0, 1.0, 1.0])).

En outre, si vous souhaitez ajouter des poids ou des importances différentes pertes, vous pouvez utiliser loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001])).

Cela signifie que pour calculer -0.1*d(loss1)/dw, d(loss2)/dw, 0.0001*d(loss3)/dw simultanément.

30voto

greenberet123 Points 558

Ici, la sortie de l'avant(), c'est à dire y est un 3-vecteur.

Les trois valeurs sont dégradés à la sortie du réseau. Ils sont généralement mis à 1,0 si y est le résultat final, mais peut avoir d'autres valeurs, en particulier si y est une partie d'un grand réseau.

Pour eg. si x est l'entrée, y = [y1, y2, y3] est un intermédiaire de la sortie qui est utilisé pour calculer le résultat final z,

Ensuite,

dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx

Donc, ici, les trois valeurs à l'arrière sont

[dz/dy1, dz/dy2, dz/dy3]

et puis vers l'arrière() calcule dz/dx

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X