4 votes

Comment exécuter une session tensorflow à l'intérieur d'une session par défaut?

Comment utiliser la prédiction de la session par défaut de tensorflow comme entrée pour une nouvelle session tensorflow. J'ai un modèle de détection, les objets détectés devraient être passés en entrée à un nouveau modèle pour la classification, lorsque j'essaie de le faire, je reçois une erreur d'épuisement des ressources. code d'exemple:

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        while True:
            sess.run([boxes, scores, classes, num_detections] )

            """ Je veux utiliser les valeurs prédites pour une autre session tensorflow pour la classification"""
           c'est-à-dire
           with tf.Session() as sess:
               "Modèle de classification"
               "Code pseudo????"

Merci

1voto

BigBadMe Points 1174

Après cela avec la session se fermera car elle sortira de la portée, donc oui, elle sera créée à chaque itération de la boucle while:

   with tf.Session() as sess:
       "Modèle de classification"
       "Pseudo code????"

Je soupçonne que vous voudriez re-arranger quelque chose comme ceci:

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        sess2 = tf.Session(graph=detection_graph)
        while True:
            sess.run([boxes, scores, classes, num_detections] )

            """ Je veux utiliser les valeurs prédites pour une autre session tensorflow pour la classification"""
            # utilisez sess2 ici
            "Modèle de classification"
            "Pseudo code????"

0voto

Balaji Points 131

Après avoir essayé toutes les méthodes différentes, j'ai pu le réparer en utilisant la méthode de classe init.

def __init__(self):
    """Tensorflow détecteur
    """
    self.detection_graph = tf.Graph()
    with self.detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(frozen_inference.pb, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    with self.detection_graph.as_default():
        config = tf.ConfigProto()
        # config.gpu_options.allow_growth = True
        self.detection_sess = tf.Session(graph=self.detection_graph, config=config)
        self.windowNotSet = True

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X