142 votes

Dans Tensorflow, obtenir les noms de tous les tenseurs dans un graphique

Je crée des réseaux neuronaux avec Tensorflow y skflow ; pour une raison quelconque, je veux obtenir les valeurs de certains tenseurs internes pour une entrée donnée, et j'utilise donc myClassifier.get_layer_value(input, "tensorName") , myClassifier être un skflow.estimators.TensorFlowEstimator .

Cependant, je trouve difficile de trouver la syntaxe correcte du nom du tenseur, même en connaissant son nom (et je m'embrouille entre opération et tenseurs), donc j'utilise tensorboard pour tracer le graphique et chercher le nom.

Existe-t-il un moyen d'énumérer tous les tenseurs dans un graphe sans utiliser tensorboard ?

202voto

Yaroslav Bulatov Points 7316

Vous pouvez faire

[n.name for n in tf.get_default_graph().as_graph_def().node]

De plus, si vous effectuez un prototypage dans un notebook IPython, vous pouvez afficher le graphique directement dans le notebook, cf. show_graph fonction dans le rêve profond d'Alexander ordinateur portable

2 votes

Vous pouvez filtrer cela pour les variables par exemple en ajoutant if "Variable" in n.op à la fin de la compréhension.

0 votes

Y a-t-il un moyen d'obtenir un nœud spécifique si vous en connaissez le nom ?

0 votes

Pour en savoir plus sur les nœuds de graphe : tensorflow.org/extend/tool_developers/#nodes

25voto

Salvador Dali Points 11667

Il y a un moyen de le faire un peu plus rapidement que dans la réponse de Yaroslav en utilisant get_operations . Voici un exemple rapide :

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))

3 votes

Vous ne pouvez pas obtenir des tenseurs en utilisant tf.get_operations() . La seule opération que vous pouvez obtenir.

11voto

Yuan Tang Points 61

tf.all_variables() peut vous donner les informations que vous voulez.

Aussi, cet engagement fait aujourd'hui dans TensorFlow Learn qui fournit une fonction get_variable_names dans l'estimateur que vous pouvez utiliser pour retrouver facilement tous les noms de variables.

2 votes

Cette fonction est obsolète

8 votes

... et son successeur est tf.global_variables()

11 votes

Cela ne récupère que les variables, pas les tenseurs.

5voto

Lu Howyou Points 21

Je pense que ça fera l'affaire aussi :

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

Mais par rapport aux réponses de Salvado et de Yaroslav, je ne sais pas laquelle est la meilleure.

0 votes

Celui-ci a fonctionné avec un graphique importé d'un fichier frozen_inference_graph.pb utilisé dans l'API de détection d'objets de tensorflow. Merci

5voto

Pepe Points 383

La réponse acceptée vous donne seulement une liste de chaînes de caractères avec les noms. Je préfère une approche différente, qui vous donne un accès (presque) direct aux tenseurs :

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples contient maintenant tous les tenseurs, chacun dans un tuple. Vous pouvez également l'adapter pour obtenir les tenseurs directement :

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X