381 votes

La meilleure façon de sauvegarder un modèle formé dans PyTorch ?

Je cherchais des moyens alternatifs pour sauvegarder un modèle formé dans PyTorch. Jusqu'à présent, j'ai trouvé deux alternatives.

  1. torche.save() pour sauvegarder un modèle et torche.load() pour charger un modèle.
  2. modèle.state_dict() pour sauvegarder un modèle formé et modèle.load_state_dict() pour charger le modèle sauvegardé.

Je suis tombé sur ceci discussion où l'approche 2 est recommandée par rapport à l'approche 1.

Ma question est la suivante : pourquoi la deuxième approche est-elle préférable ? Est-ce seulement parce que torche.nn les modules ont ces deux fonctions et nous sommes encouragés à les utiliser ?

3 votes

Je pense que c'est parce que torch.save() enregistre toutes les variables intermédiaires ainsi, comme les sorties intermédiaires pour l'utilisation de la rétro propagation. Mais vous avez seulement besoin de sauvegarder les paramètres du modèle, comme le poids/le biais, etc. Parfois, le premier peut être beaucoup plus grand que le second.

3 votes

J'ai testé torch.save(model, f) y torch.save(model.state_dict(), f) . Les fichiers enregistrés ont la même taille. Maintenant je suis confus. De plus, j'ai trouvé l'utilisation de pickle pour sauvegarder model.state_dict() extrêmement lente. Je pense que le meilleur moyen est d'utiliser torch.save(model.state_dict(), f) puisque vous vous occupez de la création du modèle, et que Torch s'occupe du chargement des poids du modèle, ce qui élimine les problèmes éventuels. Référence : discuss.pytorch.org/t/saving-torch-models/838/4

2 votes

Il semble que PyTorch ait abordé cette question de manière un peu plus explicite dans son document intitulé section des didacticiels -Vous y trouverez de nombreuses informations utiles qui ne figurent pas dans les réponses données ici, notamment la sauvegarde de plusieurs modèles à la fois et le démarrage à chaud des modèles.

385voto

dontloo Points 2754

J'ai trouvé cette page sur leur repo github, je vais juste coller le contenu ici.


Approche recommandée pour la sauvegarde d'un modèle

Il existe deux approches principales pour sérialiser et restaurer un modèle.

La première (recommandée) sauvegarde et charge uniquement les paramètres du modèle :

torch.save(the_model.state_dict(), PATH)

Puis plus tard :

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

La seconde sauvegarde et charge le modèle entier :

torch.save(the_model, PATH)

Puis plus tard :

the_model = torch.load(PATH)

Cependant, dans ce cas, les données sérialisées sont liées aux classes spécifiques spécifiques et à la structure exacte du répertoire utilisé, ce qui fait qu'elles peuvent se briser de diverses manières lorsqu'elles sont lorsqu'elles sont utilisées dans d'autres projets, ou après de sérieuses refactorisations.

28 votes

Selon @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/ le modèle se recharge pour former le modèle par défaut. il faut donc appeler manuellement the_model.eval() après le chargement, si vous le chargez pour l'inférence et non pour reprendre la formation.

0 votes

La deuxième méthode donne stackoverflow.com/questions/53798009/ erreur sur Windows 10. Je n'ai pas réussi à la résoudre.

0 votes

Existe-t-il une option pour sauvegarder sans avoir besoin d'un accès pour la classe modèle ?

256voto

ASDF Points 123

Cela dépend de ce que vous voulez faire.

Cas # 1 : Sauvegarder le modèle pour l'utiliser soi-même pour l'inférence : Vous sauvegardez le modèle, vous le restaurez, puis vous passez le modèle en mode évaluation. Ceci est fait parce que vous avez généralement BatchNorm y Dropout les couches qui, par défaut, sont en mode train sur la construction :

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Cas # 2 : Sauvegarder le modèle pour reprendre l'entraînement plus tard : Si vous devez continuer à former le modèle que vous êtes sur le point de sauvegarder, vous devez sauvegarder plus que le modèle. Vous devez également sauvegarder l'état de l'optimiseur, les époques, le score, etc. Vous devez procéder de la manière suivante :

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Pour reprendre la formation, vous feriez des choses comme : state = torch.load(filepath) et ensuite, pour restaurer l'état de chaque objet individuel, quelque chose comme ceci :

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Puisque vous reprenez l'entraînement, NE PAS appelez model.eval() une fois que vous aurez restauré les états lors du chargement.

Cas n° 3 : Modèle destiné à être utilisé par quelqu'un d'autre n'ayant pas accès à votre code : Dans Tensorflow, vous pouvez créer un .pb qui définit à la fois l'architecture et les poids du modèle. C'est très pratique, surtout quand on utilise Tensorflow serve . La façon équivalente de faire ceci dans Pytorch serait :

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Cette méthode n'est pas encore infaillible et comme pytorch subit encore de nombreux changements, je ne la recommande pas.

4 votes

Y a-t-il une fin de dossier recommandée pour les 3 cas ? Ou est-ce que c'est toujours .pth ?

4 votes

Dans l'affaire n°3 torch.load renvoie juste un OrderedDict. Comment obtenir le modèle afin de faire des prédictions ?

0 votes

Bonjour, Puis-je savoir comment faire le "Cas # 2 : Sauvegarder le modèle pour reprendre la formation plus tard" ? J'ai réussi à charger le checkpoint dans le modèle, puis je n'ai pas pu exécuter ou reprendre l'entraînement du modèle comme "model.to(device) model = train_model_epoch(model, criterion, optimizer, sched, epochs)".

36voto

prosti Points 4630

El cornichon La bibliothèque Python met en œuvre des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.

Quand vous import torch (ou lorsque vous utilisez PyTorch), il s'agit de import pickle pour vous et vous n'avez pas besoin d'appeler pickle.dump() y pickle.load() directement, qui sont les méthodes de sauvegarde et de chargement de l'objet.

En fait, torch.save() y torch.load() enveloppera pickle.dump() y pickle.load() pour vous.

A state_dict l'autre réponse mentionnée mérite quelques remarques supplémentaires.

Quoi state_dict avons-nous à l'intérieur de PyTorch ? Il y a en fait deux state_dict s.

Le modèle PyTorch est torch.nn.Module qui a model.parameters() pour obtenir les paramètres apprenables (w et b). Ces paramètres apprenables, une fois fixés de manière aléatoire, seront mis à jour au fur et à mesure de l'apprentissage. Les paramètres apprenables sont les premiers state_dict .

Le deuxième state_dict est le dictateur d'état de l'optimiseur. Vous vous souvenez que l'optimiseur est utilisé pour améliorer nos paramètres apprenables. Mais l'optimiseur state_dict est fixé. Il n'y a rien à apprendre.

Parce que state_dict sont des dictionnaires Python, ils peuvent être facilement sauvegardés, mis à jour, modifiés et restaurés, ce qui ajoute une grande modularité aux modèles et optimiseurs PyTorch.

Créons un modèle très simple pour expliquer cela :

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Ce code donnera le résultat suivant :

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d'ajouter des piles de

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Notez que seules les couches dont les paramètres peuvent être appris (couches convolutionnelles, couches linéaires, etc.) et les tampons enregistrés (couches batchnorm) ont des entrées dans le modèle state_dict .

Les éléments non-apprenables appartiennent à l'objet optimiseur. state_dict qui contient des informations sur l'état de l'optimiseur, ainsi que sur les hyperparamètres utilisés.

La suite de l'histoire est la même ; dans la phase d'inférence (c'est une phase où nous utilisons le modèle après l'entraînement) pour la prédiction ; nous prédisons sur la base des paramètres que nous avons appris. Donc, pour l'inférence, nous avons juste besoin de sauvegarder les paramètres model.state_dict() .

torch.save(model.state_dict(), filepath)

Et pour l'utiliser plus tard model.load_state_dict(torch.load(filepath)) modèle.eval()

Note : N'oubliez pas la dernière ligne model.eval() ceci est crucial après le chargement du modèle.

N'essayez pas non plus de sauver torch.save(model.parameters(), filepath) . Le site model.parameters() est juste l'objet générateur.

D'un autre côté, torch.save(model, filepath) enregistre l'objet modèle lui-même, mais gardez à l'esprit que le modèle ne dispose pas de l'option state_dict . Consultez l'autre excellente réponse de @Jadiel de Armas pour sauvegarder le dict d'état de l'optimiseur.

2 votes

Bien qu'il ne s'agisse pas d'une solution directe, l'essence du problème est analysée en profondeur ! Upvote.

23voto

harsh Points 341

Une convention commune à PyTorch est d'enregistrer les modèles en utilisant une extension de fichier .pt ou .pth.

Sauvegarder/charger le modèle entier

Économisez :

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Charge :

(La classe de modèle doit être définie quelque part)

model.load_state_dict(torch.load(PATH))
model.eval()

0 votes

Il a soulevé : AttributeError : L'objet 'dict' n'a pas d'attribut 'eval'.

16voto

Joy Mazumder Points 58

Si vous voulez sauvegarder le modèle et voulez reprendre la formation plus tard :

Un seul GPU : Économisez :

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Charge :

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU multiples : Sauvez

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Charge :

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X