PyTorch Model Visualization

Environment

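Install graphviz and pytorchviz. Note that the pip package graphviz is only the Python binding; rendering also requires the Graphviz executables (such as dot) to be installed on the system and available on PATH.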

pip install graphviz
pip install git+https://github.com/szagoruyko/pytorchviz
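Before running the example, it can help to confirm that everything is actually in place. The following is a minimal sketch using only the standard library; the helper name check_environment is hypothetical, and it assumes the two packages import as graphviz and torchviz and that the Graphviz executable is called dot.

```python
import importlib.util
import shutil


def check_environment():
    """Report whether each visualization dependency can be found.

    check_environment is a hypothetical helper for illustration only.
    """
    return {
        # Python binding installed by `pip install graphviz`
        "graphviz (Python package)": importlib.util.find_spec("graphviz") is not None,
        # Package installed from the pytorchviz repository
        "torchviz (Python package)": importlib.util.find_spec("torchviz") is not None,
        # System executable that performs the actual rendering
        "dot (Graphviz executable)": shutil.which("dot") is not None,
    }


status = check_environment()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

If the dot executable is reported as missing, the Python packages will still import, but rendering a graph is likely to fail until Graphviz itself is installed.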

Example:

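import torch
from torchviz import make_dot
from torchvision.models import AlexNet

# Build an AlexNet with randomly initialized weights
model = AlexNet()

# Dummy input batch: one 3-channel 227x227 image.
# requires_grad_(True) makes the input appear as a node in the graph.
x = torch.randn(1, 3, 227, 227).requires_grad_(True)
y = model(x)

# Label graph nodes with the parameter names, plus the input tensor as 'x'
vis_graph = make_dot(y, params=dict(list(model.named_parameters()) + [('x', x)]))

# Render the graph and open it in the default viewer
vis_graph.view()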

Using a model viewer tool