68 votes

Un moyen simple de visualiser un graphique TensorFlow dans Jupyter?

La manière officielle de visualiser un graphique TensorFlow est d'utiliser TensorBoard, mais parfois, je souhaite simplement jeter un coup d'œil rapide sur le graphique lorsque je travaille dans Jupyter.

Existe-t-il une solution rapide, idéalement basée sur les outils TensorFlow ou sur les packages SciPy standard (tels que matplotlib), mais éventuellement sur la base de bibliothèques tierces?

98voto

Yaroslav Bulatov Points 7316

Voici une recette que j'ai copié à partir de l'un de Alex Mordvintsev rêve profond ordinateur portable à un certain point

from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = "<stripped %d bytes>"%size
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

Ensuite de visualiser graphe en cours

show_graph(tf.get_default_graph().as_graph_def())

Si votre graphique est enregistré comme pbtxt, vous pourriez faire

gdef = tf.GraphDef()
from google.protobuf import text_format
text_format.Merge(open("tf_persistent.pbtxt").read(), gdef)
show_graph(gdef)

Vous verrez quelque chose comme ceci

enter image description here

15voto

Liu Shengpeng Points 111

J'ai écrit une extension Jupyter pour l'intégration de tensorboard. Ça peut:

  1. Lancez tensorboard en cliquant simplement sur un bouton dans Jupyter
  2. Gérer plusieurs instances de tensorboard.
  3. Intégration transparente avec l'interface Jupyter.

Github: https://github.com/lspvic/jupyter_tensorboard

5voto

Salvador Dali Points 11667

J'ai écris un helper qui commence un tensorboard de la jupyter ordinateur portable. Il suffit d'ajouter cette fonction quelque part dans le haut de votre ordinateur portable

def TB(cleanup=False):
    import webbrowser
    webbrowser.open('http://127.0.1.1:6006')

    !tensorboard --logdir="logs"

    if cleanup:
        !rm -R logs/

Et puis exécutez - TB() chaque fois que vous avez généré votre sommaire. Au lieu d'ouvrir un graphique dans le même jupyter fenêtre, il:

  • commence une tensorboard
  • ouvre un nouvel onglet avec tensorboard
  • accédez-vous à cette onglet

Après vous avez terminé avec l'exploration, cliquez simplement sur l'onglet, et de cesser d'interrompre le noyau. Si vous voulez nettoyage de votre répertoire des journaux, après la course, il suffit d'exécuter TB(1)

4voto

skicavs Points 61

Une version gratuite de Tensorboard / iframes de cette visualisation qui s’avère encombrée rapidement peut

 import pydot
from itertools import chain
def tf_graph_to_dot(in_graph):
    dot = pydot.Dot()
    dot.set('rankdir', 'LR')
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')
    all_ops = in_graph.get_operations()
    all_tens_dict = {k: i for i,k in enumerate(set(chain(*[c_op.outputs for c_op in all_ops])))}
    for c_node in all_tens_dict.keys():
        node = pydot.Node(c_node.name)#, label=label)
        dot.add_node(node)
    for c_op in all_ops:
        for c_output in c_op.outputs:
            for c_input in c_op.inputs:
                dot.add_edge(pydot.Edge(c_input.name, c_output.name))
    return dot
 

qui peut ensuite être suivi de

 from IPython.display import SVG
# Define model
tf_graph_to_dot(graph).write_svg('simple_tf.svg')
SVG('simple_tf.svg')
 

rendre le graphique sous forme d'enregistrements dans un fichier SVG statique Graphique Tensorflow intégré en point

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X