358 votes

Pourquoi avons-nous besoin d'appeler zero_grad() dans PyTorch ?

La méthode zero_grad() doit être appelé pendant la formation. Mais le documentation n'est pas très utile

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

Pourquoi devons-nous appeler cette méthode ?

509voto

mario23 Points 322

Sur PyTorch nous devons mettre les gradients à zéro avant de commencer la rétropropagation car PyTorch accumule les gradients lors des passages en arrière suivants. Ceci est pratique lors de l'entraînement des RNNs. Ainsi, l'action par défaut est accumuler (c'est-à-dire additionner) les gradients sur chaque loss.backward() appeler.

Pour cette raison, lorsque vous commencez votre boucle d'entraînement, vous devriez idéalement zero out the gradients pour que vous fassiez la mise à jour des paramètres correctement. Sinon, le gradient pointerait dans une autre direction que celle prévue, vers l'objectif de l'opération. minimum (ou maximum dans le cas d'objectifs de maximisation).

Voici un exemple simple :

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Alternativement, si vous faites un descente de gradient de vanille alors :

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Note :

  • Le site accumulation (c'est-à-dire somme ) de gradients se produisent lorsque .backward() est appelé sur le loss tenseur .
  • Depuis la version 1.7.0, il est possible de réinitialiser les gradients à l'aide de la fonction None optimizer.zero_grad(set_to_none=True) au lieu de le remplir avec un tenseur de zéros. La documentation affirme que ce paramètre permet de réduire les besoins en mémoire et d'améliorer légèrement les performances, mais il peut être source d'erreurs s'il n'est pas utilisé avec précaution.

4 votes

Merci beaucoup, c'est vraiment utile ! Savez-vous par hasard si le tensorflow a le même comportement ?

1 votes

Juste pour être sûr si vous ne faites pas cela, alors vous rencontrerez un problème de gradient explosif, n'est-ce pas ?

28 votes

@zwep Si nous accumulons les gradients, cela ne signifie pas que leur magnitude augmente : un exemple serait si le signe du gradient ne cesse de s'inverser. Cela ne garantit donc pas que vous rencontriez le problème du gradient explosif. En outre, les gradients explosifs existent même si vous mettez correctement à zéro.

30voto

jerryIsHere Points 21

Bien que l'idée puisse être dérivée de la réponse choisie, mais j'ai envie de l'écrire explicitement.

Être capable de décider quand appeler optimizer.zero_grad() et optimizer.step() offre plus de liberté sur la façon dont le gradient est accumulé et appliqué par l'optimiseur dans la boucle d'apprentissage. Ceci est crucial lorsque le modèle ou les données d'entrée sont volumineux et qu'un lot d'entraînement réel ne tient pas dans la carte graphique.

Ici, dans cet exemple de google-research, il y a deux arguments, nommés train_batch_size et gradient_accumulation_steps .

  • train_batch_size est la taille du lot pour le passage vers l'avant, suivant la méthode de la loss.backward() . Ceci est limité par la mémoire du processeur graphique.

  • gradient_accumulation_steps est la taille réelle du lot de formation, où la perte de plusieurs passages en avant est accumulée. C'est PAS limité par la mémoire du processeur.

Dans cet exemple, vous pouvez voir comment optimizer.zero_grad() peut être suivi par optimizer.step() mais PAS loss.backward() . loss.backward() est invoqué à chaque itération (ligne 216) mais optimizer.zero_grad() et optimizer.step() n'est invoqué que lorsque le nombre de lots de train accumulés est égal au gradient_accumulation_steps (ligne 227 à l'intérieur du if à la ligne 219)

https://github.com/google-research/xtreme/blob/master/third_party/run_classify.py

Quelqu'un a également posé une question sur la méthode équivalente dans TensorFlow. Je suppose tf.GradientTape servent le même objectif.

(Je suis encore novice en matière de bibliothèque d'IA, corrigez-moi si j'ai dit quelque chose de faux).

2voto

zero_grad() relance le bouclage sans pertes à partir de la dernière étape si vous utilisez la méthode du gradient pour diminuer l'erreur (ou les pertes).

Si vous n'utilisez pas zero_grad() la perte augmentera et ne diminuera pas comme il se doit.

Par exemple :

Si vous utilisez zero_grad() vous obtiendrez le résultat suivant :

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

Si vous n'utilisez pas zero_grad() vous obtiendrez le résultat suivant :

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5

9 votes

C'est pour le moins déroutant. Quelle boucle est redémarrée ? L'augmentation/diminution de la perte est affectée indirectement, elle peut augmenter quand vous faites .zero_grad() et il peut diminuer quand vous ne le faites pas. D'où viennent les sorties que vous montrez ?

0 votes

Cher dedObed (cet exemple pour si vous supprimez zero_grad de votre code correctement), nous parlons de la fonction .zero_grad() , cette fonction ne fait que commencer à boucler sans le dernier résultat si la perte est croissante est croissante vous devriez revoir votre entrée ( écrivez votre problème dans un nouveau sujet et git moi le lien.

7 votes

Je (pense que je) comprends assez bien PyTorch. Je ne fais que pointer ce que je perçois comme des défauts dans votre réponse - elle n'est pas claire, tire des conclusions rapides, montre des sorties qui savent quoi.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X