PyTorch学习总结(三)——ONNX
来源:互联网 发布:阿里云域名解析设置 编辑:程序博客网 时间:2024/06/05 13:29
1.什么是ONNX
Open Neural Network Exchange (ONNX)是开放生态系统的第一步,它使人工智能开发人员可以在项目的发展过程中选择合适的工具;ONNX为AI models提供了一种开源格式。它定义了一个可以扩展的计算图模型,同时也定义了内置操作符和标准数据类型。最初我们关注的是推理(评估)所需的能力。
Caffe2, PyTorch, Microsoft Cognitive Toolkit, Apache MXNet 和其他工具都在对ONNX进行支持。在不同的框架之间实现互操作性,并简化从研究到产品化的过程,将提高人工智能社区的创新速度。
2.torch.onnx
本文我们将主要介绍PyTorch中自带的torch.onnx模块。该模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。
3 End-to-end AlexNet from PyTorch to Caffe2
这里有一个简单的脚本,它将torchvision中预训练的AlexNet导出为ONNX。它运行一个简单的推断,然后将生成的跟踪模型保存到alexnet.proto
:
from torch.autograd import Variableimport torch.onnximport torchvisiondummy_input = Variable(torch.randn(10, 3, 224, 224)).cuda()model = torchvision.models.alexnet(pretrained=True).cuda()torch.onnx.export(model, dummy_input, "alexnet.proto", verbose=True)
alexnet.proto
是一个二进制的protobuf文件,它包含您导出的模型的网络结构和参数(在这里,模型是AlexNet)。关键参数verbose=True
使exporter可以打印出一种人类可读的网络表示:
# All parameters are encoded explicitly as inputs. By convention,# learned parameters (ala nn.Module.state_dict) are first, and the# actual inputs are last.graph(%1 : Float(64, 3, 11, 11) %2 : Float(64) # The definition sites of all variables are annotated with type # information, specifying the type and size of tensors. # For example, %3 is a 192 x 64 x 5 x 5 tensor of floats. %3 : Float(192, 64, 5, 5) %4 : Float(192) # ---- omitted for brevity ---- %15 : Float(1000, 4096) %16 : Float(1000) %17 : Float(10, 3, 224, 224)) { # the actual input! # Every statement consists of some output tensors (and their types), # the operator to be run (with its attributes, e.g., kernels, strides, # etc.), its input tensors (%17, %1) %19 : UNKNOWN_TYPE = Conv[kernels=[11, 11], strides=[4, 4], pads=[2, 2, 2, 2], dilations=[1, 1], group=1](%17, %1), uses = [[%20.i0]]; # UNKNOWN_TYPE: sometimes type information is not known. We hope to eliminate # all such cases in a later release. %20 : Float(10, 64, 55, 55) = Add[broadcast=1, axis=1](%19, %2), uses = [%21.i0]; %21 : Float(10, 64, 55, 55) = Relu(%20), uses = [%22.i0]; %22 : Float(10, 64, 27, 27) = MaxPool[kernels=[3, 3], pads=[0, 0, 0, 0], dilations=[1, 1], strides=[2, 2]](%21), uses = [%23.i0]; # ... # Finally, a network returns some tensors return (%58);}
你还可以使用onnx库来验证protobuf。你可以用conda安装onnx
:
conda install -c conda-forge onnx
然后,你可以运行:
import onnx# Load the ONNX modelmodel = onnx.load("alexnet.proto")# Check that the IR is well formedonnx.checker.check_model(model)# Print a human readable representation of the graphonnx.helper.printable_graph(model.graph)
为了运行导出的caffe2版本的脚本,你需要以下两项支持:
- 你需要安装caffe2。如果你还没有安装,请按照以下说明进行安装:https://caffe2.ai/docs/getting-started.html。
你需要安装onnx-caffe2,一个纯Python库,它为ONNX提供了一个caffe2的编译器。你可以用pip安装onnx-caffe2:
pip install onnx-caffe2
安装好以上依赖后,你可以使用针对Caffe2的编译器了:
# ...continuing from aboveimport onnx_caffe2.backend as backendimport numpy as np# or "CPU"rep = backend.prepare(model, device="CUDA:0") # For the Caffe2 backend:# rep.predict_net is the Caffe2 protobuf for the network# rep.workspace is the Caffe2 workspace for the network# (see the class onnx_caffe2.backend.Workspace)outputs = rep.run(np.random.randn(10, 3, 224, 224).astype(np.float32))# To run networks with more than one input, pass a tuple# rather than a single numpy ndarray.print(outputs[0])
局限性
ONNX exporter是一个基于跟踪的exporter,这意味着它在执行您的模型时运行一次,并导出在运行期间实际运行的操作。这意味着如果你的模型是动态的,例如,根据输入数据改变操作行为,则export是不准确的。类似地,跟踪可能只对特定的输入大小有效(这就是为什么我们需要对跟踪进行显式输入的原因之一)。我们建议检查模型跟踪,并确保跟踪的操作符看起来是合理的。
PyTorch和Caffe2通常有一些操作符的结果存在数值差异。根据模型结构,这些差异可以忽略不计,但也可能导致行为产生重大差异(特别对于那些未经训练的模型)。在未来的版本中,我们打算让Caffe2可以直接调用Torch中的一些操作,使得在一些关注精度的任务中,可以帮助研究人员缓和这些差异,同时也会记录下这些差异。
支持的操作符
ONNX支持下面的操作符:
- add (nonzero alpha not supported)
- sub (nonzero alpha not supported)
- mul
- div
- cat
- mm
- addmm
- neg
- tanh
- sigmoid
- mean
- t
- expand (only when used before a broadcasting ONNX operator; e.g., add)
- transpose
- view
- split
- squeeze
- prelu (single weight shared among input channels not supported)
- threshold (non-zero threshold/non-zero value not supported)
- leaky_relu
- glu
- softmax
- avg_pool2d (ceil_mode not supported)
- log_softmax
- unfold (experimental support with ATen-Caffe2 integration)
- elu
- Conv
- BatchNorm
- MaxPool1d (ceil_mode not supported)
- MaxPool2d (ceil_mode not supported)
- MaxPool3d (ceil_mode not supported)
- Embedding (no optional arguments supported)
- RNN
- ConstantPadNd
- Dropout
- FeatureDropout (training mode not supported)
- Index (constant integer and tuple indices supported)
- Negate
上面操作符集合足以导出以下模型:
- AlexNet
- DCGAN
- DenseNet
- Inception (warning: this model is highly sensitive to changes in operator implementation)
- ResNet
- SuperResolution
- VGG
- word_language_model
指定操作符定义的接口是高度实验性和无文档的;喜欢尝鲜的用户请注意,APIs可能会在未来的接口中发生变化
函数
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)
将一个模型导出到ONNX格式。该exporter会运行一次你的模型,以便于记录模型的执行轨迹,并将其导出;目前,exporter还不支持动态模型(例如,RNNs)。
另请参阅:onnx-export
参数:
- model(torch.nn.Module)-要被导出的模型
- args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
- f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
- export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由
model.state_dict().values()
指定。 - verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
- training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
- input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
- output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。
- PyTorch学习总结(三)——ONNX
- PyTorch学习—PyTorch是什么?
- PyTorch学习总结(四)——Utilities
- PyTorch学习系列(三)——构建神经网络
- 莫烦PyTorch学习笔记(三)——分类
- PyTorch学习总结(一)——查看模型中间结果
- PyTorch学习总结(五)——torch.nn
- PyTorch学习总结(六)——Tensor实现
- PyTorch学习总结(七)——自动求导机制
- pytorch学习总结
- PyTorch学习3—神经网络
- Pytorch学习笔记(三)
- 基于PyTorch的深度学习入门教程(三)——自动梯度
- PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
- pytorch学习笔记(三):自动求导
- PyTorch基本用法(三)——激活函数
- 基于PyTorch的深度学习入门教程(一)——PyTorch安装和配置
- 基于PyTorch的深度学习入门教程(七)——PyTorch重点综合实践
- 数据仓库分层
- 关于NMDS的一知半解
- java中静态初始化块,实例初始化块,构造函数区别
- 【C#】LINQ使用
- 一个小工具类
- PyTorch学习总结(三)——ONNX
- redis的主从复制,读写分离,主从切换
- vue.js移动端app实战3:从一个购物车入门vuex
- C语言 内存管理--指针的函数传递
- 杭电ACM OJ 1040 As Easy As A+B 水(快速排序小变形 或者 维护当前值输出)
- tensorflow中conv2d卷积测试
- 【机器学习圈子里的裙带关系】学术“朋友圈”罪与罚
- 学习记录4
- 用端口复用技术解决tcp连接下服务器主动关闭连接后不能立即重启的问题