50 votes

Tensorflow : Comment obtenir un tenseur par son nom ?

J'ai du mal à récupérer un tenseur par son nom, je ne sais même pas si c'est possible.

J'ai une fonction qui crée mon graphique :

def create_structure(tf, x, input_size,dropout):    
 with tf.variable_scope("scale_1") as scope:
  W_S1_conv1 = deep_dive.weight_variable_scaling([7,7,3,64], name='W_S1_conv1')
  b_S1_conv1 = deep_dive.bias_variable([64])
  S1_conv1 = tf.nn.relu(deep_dive.conv2d(x_image, W_S1_conv1,strides=[1, 2, 2, 1], padding='SAME') + b_S1_conv1, name="Scale1_first_relu")
.
.
.
return S3_conv1,regularizer

Je veux accéder à la variable S1_conv1 en dehors de cette fonction. J'ai essayé :

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('Scale1_first_relu')

Mais cela me donne une erreur :

Erreur de valeur : Sous-partage : La variable scale_1/Scale1_first_relu n'existe pas, disallowed. Voulez-vous dire que vous avez défini reuse=None dans VarScope ?

Mais ça marche :

with tf.variable_scope('scale_1') as scope_conv: 
 tf.get_variable_scope().reuse_variables()
 ft=tf.get_variable('W_S1_conv1')

Je peux contourner ce problème avec

return S3_conv1,regularizer, S1_conv1

mais je ne veux pas faire ça.

Je pense que mon problème est que S1_conv1 n'est pas vraiment une variable, c'est juste un tenseur. Existe-t-il un moyen de faire ce que je veux ?

58voto

apfalz Points 608

Il existe une fonction tf.Graph.get_tensor_by_name(). Par exemple :

import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    test =  sess.run(e)
    print e.name #example:0
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print test #Tensor("example:0", shape=(2, 2), dtype=float32)

3 votes

Pour référence, si vous avez besoin d'obtenir un OP au lieu d'un tenseur : stackoverflow.com/questions/42685994/

1 votes

@apfalz Comment arriver à "exemple:0" pour le nom du tenseur ?

1 votes

Je ne suis pas sûr d'avoir compris votre question mais, dans l'exemple ci-dessus, si vous imprimez e.name vous sauriez que le nom est example:0 . Tensorflow ajoute le :0 au nom que vous avez spécifié.

34voto

Yaroslav Bulatov Points 7316

Tous les tenseurs ont des noms de chaîne que vous pouvez voir comme suit

[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

Une fois que vous connaissez le nom, vous pouvez récupérer le Tensor en utilisant <name>:0 (0 fait référence au point d'arrivée, ce qui est quelque peu redondant).

Par exemple, si vous faites ceci

tf.constant(1)+tf.constant(2)

Vous avez les noms de Tensor suivants

[u'Const', u'Const_1', u'add']

Vous pouvez donc récupérer le résultat de l'addition comme suit

sess.run('add:0')

Remarque : cette partie ne fait pas partie de l'API publique. Les noms de tenseurs à chaîne générés automatiquement sont un détail d'implémentation et peuvent changer.

1 votes

Merci pour votre aide, mais ce n'est pas vraiment mon problème. Mon tenseur a été explicitement nommé ""Scale1_first_relu"", mais je ne peux pas obtenir une référence à celui-ci en dehors de la fonction où il a été déclaré. Je peux cependant obtenir une référence à la variable "W_S1_conv1". Les tenseurs sont-ils locaux ? Existent-ils en dehors de la fonction où ils sont créés ?

0 votes

Les tenseurs existent dans un graphique. Si vous utilisez un Graph par défaut, il est partagé entre toutes les fonctions du même thread. Le site Scale1_first_relu est plus une "suggestion" qu'un nom réel. Le nom réel dans un graphe peut avoir un préfixe s'il a été créé à l'intérieur d'une portée ou il peut avoir un suffixe automatiquement ajouté pour la déduplication. Imprimez simplement le nom des tenseurs dans le graphe en utilisant la recette ci-dessus et recherchez les tenseurs contenant Scale1_first_relu chaîne de caractères

0 votes

J'ai déjà essayé. Le nom complet est "scale_1 \Scale1_first_relu :0". La recherche de cet élément me donne toujours une erreur, mais la recherche de 'scale_1 \W_S1_conv1 fonctionne. C'est comme si le tenseur n'existait plus.

0voto

Kislay Kunal Points 61

Tout ce que tu dois faire dans ce cas est :

ft=tf.get_variable('scale1/Scale1_first_relu:0')

0voto

Kris Stern Points 570

Ou, plus simplement encore, le déduire de la correspondance .pbtxt qui est fourni avec le modèle .pb fichier. Comme cela dépend du modèle, chaque cas est différent.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X