170 votes

pytorch - connexion entre loss.backward() et optimizer.step()

Où se trouve un lien explicite entre le optimizer et le loss ?

Comment l'optimiseur sait-il où obtenir les gradients de la perte sans un appel comme celui-ci ? optimizer.step(loss) ?

-Plus de contexte.

Lorsque je minimise la perte, je n'ai pas eu à transmettre les gradients à l'optimiseur.

loss.backward() # Back Propagation
optimizer.step() # Gardient Descent

128voto

Shai Points 24484

Sans trop entrer dans les détails internes de pytorch, je peux offrir une réponse simpliste :

Rappelons que lors de l'initialisation de optimizer vous lui indiquez explicitement quels paramètres (tenseurs) du modèle il doit mettre à jour. Les gradients sont "stockés" par les tenseurs eux-mêmes (ils ont une valeur de grad et un requires_grad ) une fois que vous avez appelé backward() sur la perte. Après avoir calculé les gradients pour tous les tenseurs du modèle, l'appel à optimizer.step() fait en sorte que l'optimiseur itère sur tous les paramètres (tenseurs) qu'il est censé mettre à jour et utilise leurs données stockées en interne. grad pour mettre à jour leurs valeurs.

Plus d'informations sur les graphes de calcul et les informations supplémentaires "grad" stockées dans les tenseurs de pytorch peuvent être trouvés dans cette réponse .

La référence aux paramètres par l'optimiseur peut parfois causer des problèmes, par exemple, lorsque le modèle est déplacé vers le GPU. après initialisation de l'optimiseur. Assurez-vous que vous avez fini de configurer votre modèle avant la construction de l'optimiseur. Voir cette réponse pour plus de détails.

55voto

Ganesh Points 311

Lorsque vous appelez loss.backward() tout ce qu'il fait, c'est calculer le gradient de la perte en fonction de tous les paramètres de la perte qui ont un rapport avec la perte. requires_grad = True et les stocker dans parameter.grad pour chaque paramètre.

optimizer.step() met à jour tous les paramètres en fonction de parameter.grad

6 votes

loss c'est calculer la perte entre deux tenseurs sans relation avec un réseau. Comment loss.backward() savoir quel réseau il doit référencer et calculer parameter.grad pour ?

0 votes

@AzizAlfoudari voir ma réponse pour une tentative de clarification :).

51voto

pseudomarvin Points 390

Cela permettra peut-être de clarifier un peu le lien entre loss.backward y optim.step (bien que les autres réponses soient pertinentes).

# Our "model"
x = torch.tensor([1., 2.], requires_grad=True)
y = 100*x

# Compute loss
loss = y.sum()

# Compute gradients of the parameters w.r.t. the loss
print(x.grad)     # None
loss.backward()      
print(x.grad)     # tensor([100., 100.])

# MOdify the parameters by subtracting the gradient
optim = torch.optim.SGD([x], lr=0.001)
print(x)        # tensor([1., 2.], requires_grad=True)
optim.step()
print(x)        # tensor([0.9000, 1.9000], requires_grad=True)

loss.backward() définit le grad de tous les tenseurs avec requires_grad=True dans le graphe computationnel dont la perte est la feuille (seulement x dans ce cas).

L'optimiseur se contente d'itérer à travers la liste des paramètres (tenseurs) qu'il a reçus lors de l'initialisation et partout où un tenseur a requires_grad=True il soustrait la valeur de son gradient stockée dans son fichier .grad (simplement multiplié par le taux d'apprentissage dans le cas de SGD). Il n'a pas besoin de savoir par rapport à quelle perte les gradients ont été calculés, il veut juste accéder à cette propriété. .grad pour qu'il puisse faire x = x - lr * x.grad

Note que si nous faisions ça dans une boucle de train, nous appellerions optim.zero_grad() car à chaque étape du train, nous voulons calculer de nouveaux gradients - nous ne nous soucions pas des gradients du lot précédent. Ne pas mettre à zéro les gradients conduirait à une accumulation de gradients entre les lots.

33voto

LollipopKnight Points 61

Certaines réponses ont bien expliqué, mais j'aimerais donner un exemple précis pour expliquer le mécanisme.

Supposons que nous ayons une fonction : z = 3 x^2 + y^3.
La formule de mise à jour du gradient de z en fonction de x et y est la suivante :

enter image description here

Les valeurs initiales sont x=1 et y=2.

x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = 3*x**2+y**3

print("x.grad: ", x.grad)
print("y.grad: ", y.grad)
print("z.grad: ", z.grad)

# print result should be:
x.grad:  None
y.grad:  None
z.grad:  None

Calculer ensuite le gradient de x et y en valeur courante (x=1, y=2)

enter image description here

# calculate the gradient
z.backward()

print("x.grad: ", x.grad)
print("y.grad: ", y.grad)
print("z.grad: ", z.grad)

# print result should be:
x.grad:  tensor([6.])
y.grad:  tensor([12.])
z.grad:  None

Enfin, l'optimiseur SGD est utilisé pour mettre à jour les valeurs de x et y selon la formule : enter image description here

# create an optimizer, pass x,y as the paramaters to be update, setting the learning rate lr=0.1
optimizer = optim.SGD([x, y], lr=0.1)

# executing an update step
optimizer.step()

# print the updated values of x and y
print("x:", x)
print("y:", y)

# print result should be:
x: tensor([0.4000], requires_grad=True)
y: tensor([0.8000], requires_grad=True)

29voto

Akavall Points 7357

Disons que nous avons défini un modèle : model et la fonction de perte : criterion et nous avons la séquence d'étapes suivante :

pred = model(input)
loss = criterion(pred, true_labels)
loss.backward()

pred aura un grad_fn qui fait référence à la fonction qui l'a créé, et le lie au modèle. Par conséquent, loss.backward() aura des informations sur le modèle avec lequel il travaille.

Essayez d'enlever grad_fn par exemple avec :

pred = pred.clone().detach()

Les gradients du modèle seront alors None et par conséquent les poids ne seront pas mis à jour.

Et l'optimiseur est lié au modèle parce que nous passons model.parameters() quand on crée l'optimiseur.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X