Using PyTorch (3): Visualizing Network Structure

Source: Internet  Published: Baidu reading software  Editor: 程序博客网  Time: 2024/05/16 15:44

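In practice, visualizing a network is very helpful for debugging. Code written by an expert on GitHub makes it possible to visualize network structure in PyTorch: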
- visualize.py

```python
from graphviz import Digraph
import torch


def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are the tensors that require grad (leaf parameters); orange
    nodes are tensors saved for backward in a torch.autograd.Function.

    Args:
        var: output tensor of the network
        params: optional dict of (name, tensor), used to label the nodes of
            parameters that require grad, e.g. dict(model.named_parameters())
    """
    if params is not None:
        # dict.values() is not indexable in Python 3, so take the first value via next(iter(...));
        # since PyTorch 0.4, Variable and Tensor are merged, so check against torch.Tensor
        assert isinstance(next(iter(params.values())), torch.Tensor)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        # e.g. torch.Size([1, 22, 224, 224]) -> "(1, 22, 224, 224)"
        return '(' + ', '.join(['%d' % v for v in size]) + ')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                # a tensor saved for backward
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # AccumulateGrad node: wraps a leaf tensor (parameter) that requires grad
                u = var.variable
                # .get() avoids a KeyError for leaf tensors that are not in params
                name = param_map.get(id(u), '') if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                # any other backward op, labeled by its class name
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                # each entry is a (node, index) pair; None means the input needs no grad
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    add_nodes(var.grad_fn)
    return dot
```
  • This approach requires the python-graphviz package (here installed into a conda environment named pytorch):
    conda install -n pytorch python-graphviz
  • Usage is as follows:
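```python
import torch
from MyNet import MyNet          # your own network definition
from visualize import make_dot

# Dummy input of shape (batch, channels, height, width); change 22 to the
# number of input channels your network expects. Since PyTorch 0.4 a plain
# tensor works, so wrapping it in Variable is no longer needed.
x = torch.randn(1, 22, 224, 224)
model = MyNet()
y = model(x)

g = make_dot(y)
# To label parameter nodes with their names, pass the parameters as well:
# g = make_dot(y, params=dict(model.named_parameters()))
g.view()  # renders the graph to a file and opens it in the default viewer
```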
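Calling g.view() (or g.render()) needs more than the Python package: python-graphviz only writes DOT source and then runs the external dot program from Graphviz to render it. The conda python-graphviz package pulls in the Graphviz binaries as a dependency, but if the bindings come from pip install graphviz, Graphviz itself must be installed separately, otherwise rendering fails with graphviz.ExecutableNotFound. A minimal stdlib-only check is sketched here; the helper name graphviz_version is hypothetical, not part of any library.

```python
import shutil
import subprocess
from typing import Optional


def graphviz_version() -> Optional[str]:
    """Hypothetical helper: return Graphviz's version banner, or None if `dot` is not on PATH."""
    dot = shutil.which("dot")  # the executable python-graphviz shells out to
    if dot is None:
        return None
    # `dot -V` writes its banner (e.g. "dot - graphviz version ...") to stderr
    result = subprocess.run([dot, "-V"], capture_output=True, text=True)
    return (result.stderr or result.stdout).strip()


if __name__ == "__main__":
    version = graphviz_version()
    print(version or "Graphviz not found: install it before calling g.view()")
```

Running this before g.view() separates "the Python package is missing" (an ImportError on from graphviz import Digraph) from "the renderer is missing" (None here).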

The resulting visualization:
[Figure: network structure visualization]
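The graph is built by walking backwards from y.grad_fn through each node's next_functions, which is the traversal make_dot performs in add_nodes. The sketch below isolates just that walk, without Graphviz, assuming only that every node exposes next_functions as (node, index) pairs with None for inputs that need no gradient. The function autograd_edges and the classes Leaf and Add are hypothetical stand-ins so the example runs without PyTorch; unlike make_dot, it does not follow saved_tensors.

```python
from typing import Any, List, Set, Tuple


def autograd_edges(root: Any) -> List[Tuple[Any, Any]]:
    """Return (input_node, output_node) edges reachable from `root`.

    Mirrors the next_functions walk in make_dot's add_nodes: a None node
    (an input that needs no gradient) is skipped, and each node's inputs
    are expanded only once.
    """
    edges: List[Tuple[Any, Any]] = []
    seen: Set[Any] = set()  # nodes whose inputs have already been expanded

    def visit(node: Any) -> None:
        if node in seen:
            return
        seen.add(node)
        for child, _ in getattr(node, "next_functions", ()):
            if child is None:
                continue
            edges.append((child, node))  # same direction as dot.edge(child, node)
            visit(child)

    visit(root)
    return edges


# Hypothetical stand-ins for autograd nodes, so the sketch runs without torch.
class Leaf:
    next_functions = ()  # a leaf has nothing further upstream


class Add:
    def __init__(self, *inputs: Any) -> None:
        self.next_functions = tuple((n, 0) for n in inputs)


if __name__ == "__main__":
    w = Leaf()
    hidden = Add(w, None)  # None plays the role of an input without grad
    out = Add(hidden, w)
    for src, dst in autograd_edges(out):
        print(type(src).__name__, "->", type(dst).__name__)
```

With a real network, autograd_edges(y.grad_fn) returns pairs of autograd node objects, and type(node).__name__ gives the op names make_dot uses as labels for non-parameter nodes. Like make_dot, the sketch recurses once per node, so very deep graphs can hit Python's recursion limit.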
