La façon dont vous enregistrez votre modèle dépend de la façon dont vous voulez y accéder à l'avenir. Si vous pouvez appeler une nouvelle instance de la classe model
tout ce que vous avez à faire est de sauvegarder/charger les poids du modèle avec model.state_dict()
:
# Save:
torch.save(old_model.state_dict(), PATH)
# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))
Si vous ne pouvez pas le faire pour une raison quelconque (ou si vous préférez une syntaxe plus simple), vous pouvez sauvegarder le modèle entier (en fait une référence au(x) fichier(s) définissant le modèle, ainsi que son state_dict) avec torch.save()
:
# Save:
torch.save(old_model, PATH)
# Load:
new_model = torch.load(PATH)
Mais comme il s'agit d'une référence à l'emplacement des fichiers définissant la classe modèle, ce code n'est pas portable à moins que ces fichiers ne soient également portés dans la même structure de répertoire.
Sauvegarde dans le nuage - TorchHub
Si vous souhaitez que votre modèle soit portable, vous pouvez facilement permettre qu'il soit importé à l'aide de la fonction torch.hub
. Si vous ajoutez un hubconf.py
dans un répertoire github, qui peut être facilement appelé à partir de PyTorch pour permettre aux utilisateurs de charger votre modèle avec ou sans poids :
hubconf.py
( github.com/repo_owner/repo_name )
dependencies = ['torch']
from my_module import mymodel as _mymodel
def mymodel(pretrained=False, **kwargs):
return _mymodel(pretrained=pretrained, **kwargs)
Modèle de chargement :
new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)
3 votes
Je pense que c'est parce que torch.save() enregistre toutes les variables intermédiaires ainsi, comme les sorties intermédiaires pour l'utilisation de la rétro propagation. Mais vous avez seulement besoin de sauvegarder les paramètres du modèle, comme le poids/le biais, etc. Parfois, le premier peut être beaucoup plus grand que le second.
3 votes
J'ai testé
torch.save(model, f)
ytorch.save(model.state_dict(), f)
. Les fichiers enregistrés ont la même taille. Maintenant je suis confus. De plus, j'ai trouvé l'utilisation de pickle pour sauvegarder model.state_dict() extrêmement lente. Je pense que le meilleur moyen est d'utilisertorch.save(model.state_dict(), f)
puisque vous vous occupez de la création du modèle, et que Torch s'occupe du chargement des poids du modèle, ce qui élimine les problèmes éventuels. Référence : discuss.pytorch.org/t/saving-torch-models/838/42 votes
Il semble que PyTorch ait abordé cette question de manière un peu plus explicite dans son document intitulé section des didacticiels -Vous y trouverez de nombreuses informations utiles qui ne figurent pas dans les réponses données ici, notamment la sauvegarde de plusieurs modèles à la fois et le démarrage à chaud des modèles.
0 votes
Qu'y a-t-il de mal à utiliser
pickle
?1 votes
@CharlieParker torch.save est basé sur pickle. Ce qui suit est tiré du tutoriel lié ci-dessus : "[torch.save] va sauvegarder l'ensemble du module en utilisant le module pickle de Python. L'inconvénient de cette approche est que les données sérialisées sont liées aux classes spécifiques et à la structure exacte du répertoire utilisé lorsque le modèle est sauvegardé. La raison en est que le module pickle ne sauvegarde pas la classe du modèle elle-même. Il enregistre plutôt un chemin d'accès au fichier contenant la classe, qui est utilisé lors du chargement. De ce fait, votre code peut se briser de diverses manières lorsqu'il est utilisé dans d'autres projets ou après des refactors."
0 votes
@DavidMiller en fait, j'ai seulement besoin d'enregistrer une
nn.Sequential model
. Vous savez comment faire ? Je n'ai pas de définition de classe modèle. Pour le séquentiel, j'ai écrit ceci, en espérant qu'un répondeur réputé confirmera : stackoverflow.com/questions/62923052/