191 votes

Que fait model.train () dans pytorch?

Appelle-t-elle forward() en nn.Module ? Je pensais que lorsque nous appelons le modèle, la méthode forward est utilisée. Pourquoi devons-nous spécifier train ()?

253voto

Umang Gupta Points 3100

model.train() indique à votre modèle que vous êtes le modèle de la formation. Donc, effectivement couches comme d'abandon, batchnorm etc. qui se comportent différents sur le train et les procédures de test de savoir ce qui se passe et peut donc se comporter en conséquence.

Plus de détails: Il définit le mode de former (voir code source). Vous pouvez appeler model.eval() ou model.train(mode=False) à dire que vous testez. Il est assez intuitif de s'attendre à ce train fonction de train de modèle, mais il ne le fait pas. Il dicte la mode.

88voto

prosti Points 4630

Voici le code de module.train():

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

Et voici l' module.eval.

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

Les Modes d' train et eval sont les deux seuls modes, nous pouvons définir le module, et ils sont exactement à l'opposé.

C'est juste un self.training drapeau et actuellement seulement d'abandon et bachnorm de soins sur le drapeau.

Par défaut, cette option est définie sur True.

16voto

kelam gautam Points 36

Il y a deux façons de faire le modèle de connaître votre intention que j'ai.e voulez-vous vous entraîner le modèle ou voulez-vous utiliser le modèle d'évaluation. Dans le cas du modèle.train() le modèle sait qu'il doit apprendre les couches et lorsque nous utilisons le modèle.eval() il indique le modèle que rien de nouveau n'est à apprendre, et le modèle est utilisé pour les tests. de modèle.eval() est également nécessaire parce que, dans pytorch si nous sommes à l'aide de batchnorm et au cours du test, si nous voulons juste passer une seule image, pytorch renvoie une erreur si le modèle.eval() n'est pas spécifié.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X