ONNX demo
来源:互联网 发布:hl系统线切割怎么编程 编辑:程序博客网 时间:2024/05/19 04:03
ONNX是facebook AI部门那帮人搞出来的东西,可以方便地把用pytorch定义并训练好的模型转换到caffe2,然后进行部署,尤其是可以部署到移动端。想想,刚刚训练好的pytorch模型马上就可以部署到android上,是不是很激动~
import ioimport numpy as npfrom torch import nnfrom torch.autograd import Variableimport torch.utils.model_zoo as model_zooimport torch.onnx
# model definition
import torch.nn as nn
import torch.nn.init as init


class SRnet(nn.Module):
    """Sub-pixel-convolution super-resolution network.

    Four stride-1 convolutions (padded so the spatial size is preserved)
    followed by a PixelShuffle that rearranges ``upscale_factor ** 2``
    channels into a single-channel image ``upscale_factor`` times larger in
    each spatial dimension.

    Args:
        upscale_factor: spatial upscaling factor r applied by PixelShuffle.
        inplace: forwarded to ``nn.ReLU``; when True the activation
            overwrites its input tensor instead of allocating a new one.
    """

    def __init__(self, upscale_factor, inplace=False):
        super(SRnet, self).__init__()
        # Honour the ``inplace`` argument; it was previously ignored and the
        # ReLU was always created with inplace=True.
        self.relu = nn.ReLU(inplace=inplace)
        # Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        # r**2 output channels are folded into an r-times larger image below.
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Map a (N, 1, H, W) input to a (N, 1, H*r, W*r) output."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        """Orthogonally initialise the conv weights.

        Layers followed by a ReLU use the ReLU gain; the last layer uses the
        default gain. (``init.orthogonal`` is the old spelling; newer PyTorch
        releases name it ``orthogonal_`` and keep this one as a deprecated
        alias.)
        """
        init.orthogonal(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal(self.conv4.weight)


# inplace=True reproduces the "ReLU (inplace)" configuration shown in the
# printed model summary of this demo.
torch_model = SRnet(upscale_factor=3, inplace=True)
# Load the pretrained weights. map_location hands every storage back
# unchanged, so the tensors are loaded onto the CPU.
def map_location(storage, loc):
    return storage


state_dict = torch.load('sr.pth', map_location=map_location)
torch_model.load_state_dict(state_dict)
# Switch to inference mode (eval() is equivalent to train(False)).
torch_model.eval()
SRnet ( (relu): ReLU (inplace) (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (pixel_shuffle): PixelShuffle (upscale_factor=3))
上面这些内容完成了pytorch端的工作,接下来要把模型导出为ONNX格式,再导入到caffe2中。导出所用的方法叫做“tracing”:提供一个输入x,让x把整个网络forward一遍,这个过程会记录下用到了哪些torch提供的operator。
# Export by tracing: one forward pass of a dummy input records which torch
# operators the network uses. The single input channel matches conv1, and
# 244x244 matches the test image used later in this post.
batch_size = 1
x = Variable(
    torch.randn(batch_size, 1, 244, 244),
    requires_grad=True,
)
# export_params=True embeds the trained weights in the .onnx file; the
# returned torch_out is kept to cross-check the Caffe2 result afterwards.
torch_out = torch.onnx._export(
    torch_model,
    x,
    "super_resolution.onnx",
    export_params=True,
)
torch_out这个输出没有什么特殊的用途,不过可以用来验证pytorch和caffe2得到的结果是否相同。导出的模型存为文件“super_resolution.onnx”。
import onnx
import onnx_caffe2.backend

# The loaded graph is a plain Python protobuf object; exporters/backends for
# other frameworks (caffe2, cntk, mxnet, tf) share this protobuf format.
graph = onnx.load("super_resolution.onnx")

# Translate the ONNX graph into something the Caffe2 backend can execute.
prepared_backend = onnx_caffe2.backend.prepare(graph)

# Only the image has to be fed; the weights were embedded at export time.
# NOTE(review): graph.input[0] is used directly as the feed key, which assumes
# the installed ONNX version exposes input names as plain strings (newer ONNX
# releases need graph.input[0].name) -- confirm against your version.
img_input = {graph.input[0]: x.data.numpy()}
c2_out = prepared_backend.run(img_input)[0]

# Caffe2 must reproduce the PyTorch output to 3 decimal places.
np.testing.assert_almost_equal(
    torch_out.data.cpu().numpy(),
    c2_out,
    decimal=3,
)
到这里,我们就成功地把pytorch定义的模型以及训练好的参数在caffe2框架下跑起来了。有没有觉得很6呢?
在cpp_caffe2下运行
cpp使用官方提供的speed_benchmark.cc这个例程。我们先生成这个cpp代码所需要的模型的文件。
from caffe2.python.predictor import mobile_exporter

# Grab the Caffe2 workspace and predict net built by the ONNX backend.
c2_workspace = prepared_backend.workspace
c2_graph = prepared_backend.predict_net

# Split into an init net (parameters) and a predict net (model definition).
init_net, predict_net = mobile_exporter.Export(
    c2_workspace,
    c2_graph,
    c2_graph.external_input,
)

# NOTE: the init net is saved as 'init_net.pd' (not '.pb'); the demo command
# that runs speed_benchmark loads exactly this file name.
with open('init_net.pd', 'wb') as f:
    f.write(init_net.SerializeToString())
with open('predict_net.pb', 'wb') as f:
    f.write(predict_net.SerializeToString())
可以看到,文件夹下面生成了init_net.pd和predict_net.pb(注意:上面代码里参数文件的扩展名写成了.pd,后面的运行命令也沿用了这个文件名)。第一个文件是模型的参数文件,第二个文件是模型的定义文件。为什么要这样做呢?把模型的定义存为文件,模型文件就和平台无关了:pytorch和caffe2都可以使用这个文件,python和cpp代码都能使用这个文件,ubuntu和android也都能使用这个文件。
# Run on caffe2_pythonfrom caffe2.proto import caffe2_pb2from caffe2.python import core, net_drawer, net_printer, visualize, workspace,utilsimport numpy as npimport osimport subprocessfrom PIL import Imagefrom skimage import io, transform
# Load the test image and split it into Y / Cb / Cr planes.
img = Image.open('./cat_244x244.jpg')
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

# Run the init net once (creates the parameter blobs), then the predict net.
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)

# Dump the net so the input / output blob names can be read off.
print(net_printer.to_string(predict_net))
# net: torch-jit-export11 = Conv(1, 2, kernels=[5L, 5L], strides=[1L, 1L], pads=[2L, 2L, 2L, 2L], dilations=[1L, 1L], group=1)12 = Add(11, 3, broadcast=1, axis=1)13 = Relu(12)15 = Conv(13, 4, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)16 = Add(15, 5, broadcast=1, axis=1)17 = Relu(16)19 = Conv(17, 6, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)20 = Add(19, 7, broadcast=1, axis=1)21 = Relu(20)23 = Conv(21, 8, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)24 = Add(23, 9, broadcast=1, axis=1)25, _onnx_dummy1 = Reshape(24, shape=[1L, 1L, 3L, 3L, 244L, 244L])26 = Transpose(25, axes=[0L, 1L, 4L, 2L, 5L, 3L])27, _onnx_dummy2 = Reshape(26, shape=[1L, 1L, 732L, 732L])
# Feed the Y plane into input blob '1' as a float32 tensor of shape
# (1, 1, H, W).
workspace.FeedBlob(
    '1',
    np.array(img_y)[None, None, :, :].astype(np.float32),
)

# One forward pass of the predict net.
workspace.RunNetOnce(predict_net)

# Read back the network output, stored in blob '27'.
img_out = workspace.FetchBlob('27')
# Save the output as an image.
# Clip to [0, 255] *before* casting: NumPy's float -> uint8 conversion does
# not saturate out-of-range values, so clipping after the cast was a no-op.
img_out_y = Image.fromarray(np.uint8(img_out[0, 0].clip(0, 255)), mode='L')

# Cb / Cr did not go through the network; upscale them bicubically to the
# output size and merge everything back into a colour image.
final_img = Image.merge(
    'YCbCr',
    [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ],
).convert('RGB')
final_img.save('./cat_superres.jpg')
# Serialize input blob '1' to a file so the C++ speed_benchmark binary can
# reload it through --input_file.
with open('input.blobproto', 'wb') as f:
    f.write(workspace.SerializeBlob('1'))
编译cpp代码
我们需要编译cpp的代码,使用如下编译命令:
# Build speed_benchmark.cc into ./demo against the installed Caffe2 libraries.
# CAFFE2_ROOT points at the Caffe2 source tree; it is only used here to find
# the bundled Eigen headers. (The assignment must be on its own line -- glued
# to the g++ command it would just set a bogus variable for that command.)
CAFFE2_ROOT=$HOME/src/caffe2
g++ speed_benchmark.cc -o demo -std=c++11 \
    -I "$CAFFE2_ROOT/third_party/eigen" \
    -lCaffe2_CPU \
    -lglog \
    -lgflags \
    -lprotobuf \
    -lpthread \
    -llmdb \
    -lleveldb \
    -lopencv_core \
    -lopencv_highgui \
    -lopencv_imgproc
能够使用这条命令的前提是:caffe2已经通过sudo make install安装到了/usr/local下。
运行cpp程序:
# Run the exported nets once: feed blob '1' from input.blobproto and dump
# output blob '27' into the current directory.
./demo \
    --init_net init_net.pd \
    --net predict_net.pb \
    --input 1 \
    --input_file input.blobproto \
    --output_folder . \
    --output 27 \
    --iter 1
等一下,我们先看一眼speed_benchmark.cc
#include <string>#include "caffe2/core/init.h"#include "caffe2/core/operator.h"#include "caffe2/proto/caffe2.pb.h"#include "caffe2/utils/proto_utils.h"#include "caffe2/utils/string_utils.h"#include "caffe2/core/logging.h"// 定义argsCAFFE2_DEFINE_string(net, "", "The given net to benchmark.");CAFFE2_DEFINE_string(init_net, "", "The given net to initialize any parameters.");CAFFE2_DEFINE_string(input, "", "Input that is needed for running the network. If " "multiple input needed, use comma separated string.");CAFFE2_DEFINE_string(input_file, "", "Input file that contain the serialized protobuf for " "the input blobs. If multiple input needed, use comma " "separated string. Must have the same number of items " "as input does.");CAFFE2_DEFINE_string(input_dims, "", "Alternate to input_files, if all inputs are simple " "float TensorCPUs, specify the dimension using comma " "separated numbers. If multiple input needed, use " "semicolon to separate the dimension of different " "tensors.");CAFFE2_DEFINE_string(output, "", "Output that should be dumped after the execution " "finishes. If multiple outputs are needed, use comma " "separated string. If you want to dump everything, pass " "'*' as the output value.");CAFFE2_DEFINE_string(output_folder, "", "The folder that the output should be written to. 
This " "folder must already exist in the file system.");CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run.");CAFFE2_DEFINE_bool(run_individual, false, "Whether to benchmark individual operators.");using std::string;using std::unique_ptr;using std::vector;int main(int argc, char** argv) { caffe2::GlobalInit(&argc, &argv); unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace()); // 读取模型参数到工作空间 caffe2::NetDef net_def; CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &net_def)); CAFFE_ENFORCE(workspace->RunNetOnce(net_def)); // 加载输入数据,提供两种方式,--input和--input_dims if (caffe2::FLAGS_input.size()) { vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input); if (caffe2::FLAGS_input_file.size()) { vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file); CAFFE_ENFORCE_EQ( input_names.size(), input_files.size(), "Input name and file should have the same number."); for (int i = 0; i < input_names.size(); ++i) { caffe2::BlobProto blob_proto; CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto)); workspace->CreateBlob(input_names[i])->Deserialize(blob_proto); } } else if (caffe2::FLAGS_input_dims.size()) { vector<string> input_dims_list = caffe2::split(';', caffe2::FLAGS_input_dims); CAFFE_ENFORCE_EQ( input_names.size(), input_dims_list.size(), "Input name and dims should have the same number of items."); for (int i = 0; i < input_names.size(); ++i) { vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]); vector<int> input_dims; for (const string& s : input_dims_str) { input_dims.push_back(caffe2::stoi(s)); } caffe2::TensorCPU* tensor = workspace->GetBlob(input_names[i])->GetMutable<caffe2::TensorCPU>(); tensor->Resize(input_dims); tensor->mutable_data<float>(); } } else { CAFFE_THROW("You requested input tensors, but neither input_file nor " "input_dims is set."); } } // 加载模型定义文件,创建模型, 
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def)); caffe2::NetBase* net = workspace->CreateNet(net_def); CHECK_NOTNULL(net); net->TEST_Benchmark( caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual); // 获得输出 string output_prefix = caffe2::FLAGS_output_folder.size() ? caffe2::FLAGS_output_folder + "/" : ""; if (caffe2::FLAGS_output.size()) { vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output); if (caffe2::FLAGS_output == "*") { output_names = workspace->Blobs(); } for (const string& name : output_names) { CAFFE_ENFORCE( workspace->HasBlob(name), "You requested a non-existing blob: ", name); string serialized = workspace->GetBlob(name)->Serialize(name); string output_filename = output_prefix + name; caffe2::WriteStringToFile(serialized, output_filename.c_str()); } } return 0;}
程序运行的结果是生成了一个名为27的文件(即输出blob “27” 序列化后的内容),我们用python把这个文件转换为jpg。
# Turn the blob dumped by the C++ demo back into an image.
# speed_benchmark writes blob '27' to <output_folder>/27, i.e. './27' here.
# Open in binary mode (and close the file): the content is a serialized
# protobuf, and ParseFromString needs bytes -- text mode fails on Python 3
# and can corrupt the data through newline translation on Windows.
blob_proto = caffe2_pb2.BlobProto()
with open('./27', 'rb') as blob_file:
    blob_proto.ParseFromString(blob_file.read())
img_out = utils.Caffe2TensorToNumpyArray(blob_proto.tensor)

# Clip to [0, 255] before casting to uint8 so out-of-range values saturate.
img_out_y = Image.fromarray(np.uint8(img_out[0, 0].clip(0, 255)), mode='L')

# Merge with bicubically upscaled Cb / Cr planes and save as RGB.
final_img = Image.merge(
    "YCbCr",
    [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ],
).convert('RGB')
final_img.save('./cat_superres_mobile.jpg')
关于放到android上运行
按照原教程的做法,可以顺利运行。这里写几点值得注意的点:
- android编译的可执行程序是静态编译,以便方便地在手机设备上运行。
- android程序并不是一定要用java写,这个例子便是用cpp写的,这也表明android-cmake这个项目真的很方便。
- 不过,如果仅仅在android上运行控制台程序,那也太不爽了,那我还要这嵌入式设备做啥子呢?所以,最后一定要放到一个有界面的app中运行,这就要用原生的java来写了。
- 关于android环境配置,最快的方式莫过于装一个android studio。
- ONNX demo
- onnx 使用初体验
- onnx on OSX
- PyTorch学习总结(三)——ONNX
- AWS 帮助构建 ONNX 开源 AI 平台
- 开源 | 微软、Facebook联手打造AI生态系统ONNX
- Demo
- demo
- demo
- demo
- demo
- demo
- DEMO
- Demo
- DEMO
- demo
- Demo
- demo
- week4-leetcode #6-ZigZag Conversion[Medium]
- tomcat部署
- Python使用网易邮箱发邮件
- Java数据结构详解(八)-Queue接口
- call apply bind
- ONNX demo
- 程序人生-从上帝视角看问题
- 1074. 宇宙无敌加法器(20)
- HDU-4424 Conquer a New Region(并查集)
- 手把手教你创建maven的web3.0项目
- <C++>9.类成员函数的定义
- Redis学习
- jupyter notebook无法正常启动
- JavaScript中怪异现象true和false