El cornichon La bibliothèque Python met en œuvre des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.
Quand vous import torch
(ou lorsque vous utilisez PyTorch), il s'agit de import pickle
pour vous et vous n'avez pas besoin d'appeler pickle.dump()
y pickle.load()
directement, qui sont les méthodes de sauvegarde et de chargement de l'objet.
En fait, torch.save()
y torch.load()
enveloppera pickle.dump()
y pickle.load()
pour vous.
A state_dict
l'autre réponse mentionnée mérite quelques remarques supplémentaires.
Quoi state_dict
avons-nous à l'intérieur de PyTorch ? Il y a en fait deux state_dict
s.
Le modèle PyTorch est torch.nn.Module
qui a model.parameters()
pour obtenir les paramètres apprenables (w et b). Ces paramètres apprenables, une fois fixés de manière aléatoire, seront mis à jour au fur et à mesure de l'apprentissage. Les paramètres apprenables sont les premiers state_dict
.
Le deuxième state_dict
est le dictateur d'état de l'optimiseur. Vous vous souvenez que l'optimiseur est utilisé pour améliorer nos paramètres apprenables. Mais l'optimiseur state_dict
est fixé. Il n'y a rien à apprendre.
Parce que state_dict
sont des dictionnaires Python, ils peuvent être facilement sauvegardés, mis à jour, modifiés et restaurés, ce qui ajoute une grande modularité aux modèles et optimiseurs PyTorch.
Créons un modèle très simple pour expliquer cela :
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Ce code donnera le résultat suivant :
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d'ajouter des piles de
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Notez que seules les couches dont les paramètres peuvent être appris (couches convolutionnelles, couches linéaires, etc.) et les tampons enregistrés (couches batchnorm) ont des entrées dans le modèle state_dict
.
Les éléments non-apprenables appartiennent à l'objet optimiseur. state_dict
qui contient des informations sur l'état de l'optimiseur, ainsi que sur les hyperparamètres utilisés.
La suite de l'histoire est la même ; dans la phase d'inférence (c'est une phase où nous utilisons le modèle après l'entraînement) pour la prédiction ; nous prédisons sur la base des paramètres que nous avons appris. Donc, pour l'inférence, nous avons juste besoin de sauvegarder les paramètres model.state_dict()
.
torch.save(model.state_dict(), filepath)
Et pour l'utiliser plus tard model.load_state_dict(torch.load(filepath)) modèle.eval()
Note : N'oubliez pas la dernière ligne model.eval()
ceci est crucial après le chargement du modèle.
N'essayez pas non plus de sauver torch.save(model.parameters(), filepath)
. Le site model.parameters()
est juste l'objet générateur.
D'un autre côté, torch.save(model, filepath)
enregistre l'objet modèle lui-même, mais gardez à l'esprit que le modèle ne dispose pas de l'option state_dict
. Consultez l'autre excellente réponse de @Jadiel de Armas pour sauvegarder le dict d'état de l'optimiseur.
3 votes
Je pense que c'est parce que torch.save() enregistre toutes les variables intermédiaires ainsi, comme les sorties intermédiaires pour l'utilisation de la rétro propagation. Mais vous avez seulement besoin de sauvegarder les paramètres du modèle, comme le poids/le biais, etc. Parfois, le premier peut être beaucoup plus grand que le second.
3 votes
J'ai testé
torch.save(model, f)
ytorch.save(model.state_dict(), f)
. Les fichiers enregistrés ont la même taille. Maintenant je suis confus. De plus, j'ai trouvé l'utilisation de pickle pour sauvegarder model.state_dict() extrêmement lente. Je pense que le meilleur moyen est d'utilisertorch.save(model.state_dict(), f)
puisque vous vous occupez de la création du modèle, et que Torch s'occupe du chargement des poids du modèle, ce qui élimine les problèmes éventuels. Référence : discuss.pytorch.org/t/saving-torch-models/838/42 votes
Il semble que PyTorch ait abordé cette question de manière un peu plus explicite dans son document intitulé section des didacticiels -Vous y trouverez de nombreuses informations utiles qui ne figurent pas dans les réponses données ici, notamment la sauvegarde de plusieurs modèles à la fois et le démarrage à chaud des modèles.
0 votes
Qu'y a-t-il de mal à utiliser
pickle
?1 votes
@CharlieParker torch.save est basé sur pickle. Ce qui suit est tiré du tutoriel lié ci-dessus : "[torch.save] va sauvegarder l'ensemble du module en utilisant le module pickle de Python. L'inconvénient de cette approche est que les données sérialisées sont liées aux classes spécifiques et à la structure exacte du répertoire utilisé lorsque le modèle est sauvegardé. La raison en est que le module pickle ne sauvegarde pas la classe du modèle elle-même. Il enregistre plutôt un chemin d'accès au fichier contenant la classe, qui est utilisé lors du chargement. De ce fait, votre code peut se briser de diverses manières lorsqu'il est utilisé dans d'autres projets ou après des refactors."
0 votes
@DavidMiller en fait, j'ai seulement besoin d'enregistrer une
nn.Sequential model
. Vous savez comment faire ? Je n'ai pas de définition de classe modèle. Pour le séquentiel, j'ai écrit ceci, en espérant qu'un répondeur réputé confirmera : stackoverflow.com/questions/62923052/