47 votes

Comment visualiser un filet dans Pytorch ?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

Je veux visualiser resnet des modèles pytorch. Comment puis-je le faire ? J'ai essayé d'utiliser torchviz mais cela donne une erreur :

'ResNet' object has no attribute 'grad_fn'

29voto

Voici trois visualisations de graphiques différentes utilisant des outils différents.

Afin de générer des exemples de visualisation, je vais utiliser un simple RNN pour effectuer une analyse des sentiments à partir d'une tutoriel en ligne :

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

Voici le résultat si vous print() le modèle.

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

Vous trouverez ci-dessous les résultats de trois outils de visualisation différents.

Pour chacun d'entre eux, vous devez disposer d'une entrée fictive qui peut passer par la fonction d'entrée du modèle. forward() méthode. Une façon simple d'obtenir cette entrée est de récupérer un lot à partir de votre Dataloader, comme ceci :

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torchviz

https://github.com/szagoruyko/pytorchviz

Je crois que cet outil génère son graphique en utilisant la passe arrière, donc toutes les boîtes utilisent les composants PyTorch pour la rétropropagation.

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

Cet outil produit le fichier de sortie suivant :

torchviz output

C'est le seul résultat qui mentionne clairement les trois couches de mon modèle, embedding , rnn et fc . Les noms des opérateurs sont tirés de la passe arrière, aussi certains d'entre eux sont-ils difficiles à comprendre.

HiddenLayer

https://github.com/waleedka/hiddenlayer

Cet outil utilise la passe avant, je crois.

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

Voici le résultat. J'aime la nuance de bleu.

hiddenlayer output

Je trouve que la sortie a trop de détails et obscurcit mon architecture. Par exemple, pourquoi unsqueeze mentionné tant de fois ?

Netron

https://github.com/lutzroeder/netron

Cet outil est une application de bureau pour Mac, Windows et Linux. Il repose sur le fait que le modèle a d'abord été exporté en Format ONNX . L'application lit ensuite le fichier ONNX et en effectue le rendu. Il est ensuite possible d'exporter le modèle vers un fichier image.

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

Voici à quoi ressemble le modèle dans l'application. Je pense que cet outil est assez astucieux : vous pouvez zoomer et faire des panoramiques, et vous pouvez explorer les couches et les opérateurs. Le seul point négatif que j'ai trouvé est qu'il ne fait que des mises en page verticales.

Netron screenshot

25voto

Shai Points 24484

make_dot attend une variable (c'est-à-dire un tenseur avec grad_fn ), et non le modèle lui-même.
essayez :

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

14voto

David James Points 8344

Vous pouvez jeter un coup d'œil à PyTorchViz ( https://github.com/szagoruyko/pytorchviz ), "Un petit paquet pour créer des visualisations de graphiques et de traces d'exécution PyTorch".

Example PyTorchViz visualization

8voto

Charlie Parker Points 715

Voici comment procéder avec torchviz si vous voulez sauvegarder l'image :

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

capture d'écran de l'image que vous obtenez :

enter image description here

source : http://www.bnikolic.co.uk/blog/pytorch-detach.html

1voto

Sushant Points 33

Vous pouvez utiliser TensorBoard pour la visualisation. TensorBoard est désormais entièrement pris en charge dans la version 1.2.0 de PyTorch. Plus d'informations : https://pytorch.org/docs/stable/tensorboard.html

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X