102 votes

Enregistrement de TensorFlow dans un fichier/chargement d'un graphe depuis un fichier

D'après ce que j'ai compris jusqu'à présent, il existe plusieurs façons différentes de transférer un graphe TensorFlow dans un fichier et de le charger ensuite dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples/des informations claires sur leur fonctionnement. Voici ce que je sais déjà :

  1. Sauvegarder les variables du modèle dans un fichier de point de contrôle (.ckpt) en utilisant l'option tf.train.Saver() et les restaurer plus tard ( source )
  2. Sauvegarder un modèle dans un fichier .pb et le charger à nouveau en utilisant tf.train.write_graph() y tf.import_graph_def() ( source )
  3. Charger un modèle à partir d'un fichier .pb, le réentraîner, et le transférer dans un nouveau fichier .pb en utilisant Bazel ( source )
  4. Geler le graphique pour sauvegarder le graphique et les poids ensemble ( source )
  5. Utilice as_graph_def() pour sauvegarder le modèle, et pour les poids/variables, les convertir en constantes ( source )

Cependant, je n'ai pas été en mesure d'éclaircir plusieurs questions concernant ces différentes méthodes :

  1. En ce qui concerne les fichiers de contrôle, sauvegardent-ils uniquement les poids formés d'un modèle ? Les fichiers de contrôle peuvent-ils être chargés dans un nouveau programme et être utilisés pour exécuter le modèle, ou servent-ils simplement à sauvegarder les poids d'un modèle à un certain moment/étape ?
  2. Concernant tf.train.write_graph() les poids/variables sont-ils également sauvegardés ?
  3. En ce qui concerne Bazel, peut-il seulement sauvegarder dans/charger à partir de fichiers .pb pour le recyclage ? Existe-t-il une commande Bazel simple permettant de transférer un graphique dans un fichier .pb ?
  4. En ce qui concerne la congélation, peut-on charger un graphique congelé à l'aide de la fonction tf.import_graph_def() ?
  5. La démo Android de TensorFlow charge le modèle Inception de Google à partir d'un fichier .pb. Si je voulais substituer mon propre fichier .pb, comment m'y prendrais-je ? Devrais-je modifier le code/méthodes natives ?
  6. En général, quelle est la différence exacte entre toutes ces méthodes ? Ou, plus largement, quelle est la différence entre as_graph_def() /.ckpt/.pb ?

En bref, ce que je cherche, c'est une méthode pour sauvegarder à la fois un graphique (comme dans, les différentes opérations et autres) et ses poids/variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les poids dans un autre programme, pour l'utiliser (pas nécessairement pour continuer/se recycler).

La documentation sur ce sujet n'est pas très claire, aussi toute réponse/information serait-elle grandement appréciée.

2 votes

L'API la plus récente et la plus complète est le métagraphe, qui vous permet de sauvegarder les trois éléments à la fois : 1) le graphique, 2) les valeurs des paramètres et 3) les collections : tensorflow.org/versions/r0.10/how_tos/meta_graph/index.html

82voto

mrry Points 1

Il existe de nombreuses façons d'aborder le problème de l'enregistrement d'un modèle dans TensorFlow, ce qui peut le rendre un peu confus. Prenons chacune de vos sous-questions à tour de rôle :

  1. Les fichiers de points de contrôle (produits par exemple en appelant saver.save() en un tf.train.Saver ) ne contiennent que les poids, et toute autre variable définie dans le même programme. Pour les utiliser dans un autre programme, vous devez recréer la structure de graphe associée (par exemple en exécutant du code pour la construire à nouveau, ou en appelant la commande tf.import_graph_def() ), qui indique à TensorFlow ce qu'il faut faire avec ces poids. Notez que l'appel à saver.save() produit également un fichier contenant un MetaGraphDef qui contient un graphique et des détails sur la façon d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails.

  2. tf.train.write_graph() n'écrit que la structure du graphe, pas les poids.

  3. Bazel n'est pas lié à la lecture ou à l'écriture de graphes TensorFlow. (Peut-être ai-je mal compris votre question : n'hésitez pas à la préciser dans un commentaire).

  4. Un graphique gelé peut être chargé en utilisant tf.import_graph_def() . Dans ce cas, les poids sont (généralement) intégrés au graphe, de sorte que vous n'avez pas besoin de charger un point de contrôle séparé.

  5. Le principal changement consisterait à mettre à jour les noms du ou des tenseurs qui sont introduits dans le modèle, et les noms du ou des tenseurs qui sont extraits du modèle. Dans la démo de TensorFlow Android, cela correspondrait aux éléments suivants inputName y outputName qui sont transmises à TensorFlowClassifier.initializeTensorFlow() .

  6. El GraphDef est la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. Par conséquent, TensorFlow utilise différents formats de stockage pour ces types de données, et l'API de bas niveau fournit différentes manières de les enregistrer et de les charger. Les bibliothèques de niveau supérieur, telles que la bibliothèque MetaGraphDef bibliothèques, Keras y skflow s'appuient sur ces mécanismes pour fournir des moyens plus pratiques de sauvegarder et de restaurer un modèle entier.

0 votes

Cela signifie-t-il que le Documentation sur l'API C++ ment, quand il est dit que vous pouvez charger le graphique sauvegardé avec tf.train.write_graph() et ensuite l'exécuter ?

2 votes

La documentation de l'API C++ ne ment pas, mais il lui manque quelques détails. Le détail le plus important est que, en plus de la fonction GraphDef sauvé par tf.train.write_graph() vous devez également vous souvenir des noms des tenseurs que vous souhaitez alimenter et récupérer lors de l'exécution du graphe (point 5 ci-dessus).

0 votes

@mrry : J'ai essayé d'utiliser l'exemple DeepDream de tensorflows, mais il semble qu'il ait besoin de modèles pré-entraînés au format pb ! J'ai lancé l'exemple Cifar10, mais il ne crée que des points de contrôle ! Je n'ai pas pu trouver de fichiers pb ou quoi que ce soit d'autre ! Comment puis-je convertir mes points de contrôle au format pb que l'exemple DeepDream utilise ?

1voto

Vous pouvez essayer le code suivant :

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X