ONNX demo

来源:互联网 发布:hl系统线切割怎么编程 编辑:程序博客网 时间:2024/05/19 04:03

ONNX(Open Neural Network Exchange)是 Facebook 与微软联合推出的开放模型格式,可以方便地把 PyTorch 中定义并训练好的模型转换到 Caffe2,然后进行部署,尤其是部署到移动端。想想看,刚刚训练好的 PyTorch 模型马上就可以部署到 Android 上,是不是很激动~

import ioimport numpy as npfrom torch import nnfrom torch.autograd import Variableimport torch.utils.model_zoo as model_zooimport torch.onnx
# Model definition
import torch.nn as nn
import torch.nn.init as init

# `init.orthogonal_` is the current (in-place) API; the PyTorch releases this
# demo was written against only provide the now-deprecated `init.orthogonal`.
# Use whichever this installation has.
_orthogonal_init = getattr(init, "orthogonal_", None) or init.orthogonal


class SRnet(nn.Module):
    """Sub-pixel convolutional super-resolution network for 1-channel images.

    Three ReLU-activated convolutions extract features; the last convolution
    emits ``upscale_factor ** 2`` channels that ``nn.PixelShuffle`` rearranges
    into one channel that is ``upscale_factor`` times larger in height and
    width.

    Args:
        upscale_factor: Integer factor applied to both height and width.
        inplace: Forwarded to ``nn.ReLU``; when True the activation overwrites
            its input tensor instead of allocating a new one.
    """

    def __init__(self, upscale_factor, inplace=False):
        super(SRnet, self).__init__()
        # Honour the caller's choice instead of hard-coding inplace=True,
        # which silently ignored this parameter.
        self.relu = nn.ReLU(inplace=inplace)
        # Stride 1 with "same" padding keeps H and W unchanged until the shuffle.
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Map an ``(N, 1, H, W)`` tensor to ``(N, 1, H*r, W*r)``, r = upscale_factor."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        """Orthogonally initialise conv weights (ReLU gain on the hidden layers)."""
        _orthogonal_init(self.conv1.weight, init.calculate_gain('relu'))
        _orthogonal_init(self.conv2.weight, init.calculate_gain('relu'))
        _orthogonal_init(self.conv3.weight, init.calculate_gain('relu'))
        _orthogonal_init(self.conv4.weight)


# inplace=True keeps the demo's original configuration, which is what the
# printed module summary ("ReLU (inplace)") shows.
torch_model = SRnet(upscale_factor=3, inplace=True)
# Load pretrained weights. Returning each storage unchanged keeps every tensor
# on the CPU, whatever device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
def map_location(storage, loc):
    return storage


state_dict = torch.load('sr.pth', map_location=map_location)
torch_model.load_state_dict(state_dict)
# Inference mode; eval() is exactly train(False).
torch_model.eval()
SRnet (  (relu): ReLU (inplace)  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  (pixel_shuffle): PixelShuffle (upscale_factor=3))

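到这里,我们就完成了 PyTorch 端的工作,接下来要先把模型导出为 ONNX 格式,再导入到 Caffe2 中。导出的方式叫做“tracing”:提供一个输入 x,让 x 在整个网络上 forward 一遍,这个过程会记录用到了哪些 torch 提供的 operator。由于记录的只是这一次 forward 实际执行的操作,x 的具体数值并不重要,但它的形状会被固定进导出的模型(可以看到下面 net_printer 输出中 Reshape 的 shape 参数里写死了 244)。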

# Export to ONNX by tracing: run one forward pass on a dummy input and record
# the operators it executes. The network has no data-dependent branches, so
# only the input's shape (batch, 1, 244, 244) matters, not its values.
batch_size = 1
x = Variable(torch.randn(batch_size, 1, 244, 244), requires_grad=True)
# Use the public `torch.onnx.export`; the private `torch.onnx._export` is not
# part of the stable API and is not guaranteed to exist across releases.
torch.onnx.export(torch_model,
                  x,
                  "super_resolution.onnx",
                  export_params=True)
# PyTorch reference output, used later to check the Caffe2 result. `_export`
# used to return this; with the model in eval mode a plain forward pass on the
# same input yields the same values.
torch_out = torch_model(x)

torch_out 是 PyTorch 模型对输入 x 的输出,它本身没有什么特殊用途,但可以用来验证 PyTorch 和 Caffe2 得到的结果是否相同。导出的模型保存为文件“super_resolution.onnx”。

import onnx
import onnx_caffe2.backend

# `onnx.load` returns the exported protobuf message. Frameworks that consume
# ONNX (caffe2, cntk, mxnet, tf, ...) all read this same protobuf format.
graph = onnx.load("super_resolution.onnx")
prepared_backend = onnx_caffe2.backend.prepare(graph)


def _first_input_name(proto):
    """Return the name of the first graph input of a loaded ONNX protobuf.

    Current ONNX returns a ModelProto whose ``graph.input`` entries are
    ValueInfoProto messages with a ``name`` field. The original code used
    ``graph.input[0]`` itself as the feed key, which only works when the entry
    is already a plain name; both forms are accepted here.
    """
    graph_proto = getattr(proto, "graph", proto)
    first_input = graph_proto.input[0]
    return getattr(first_input, "name", first_input)


# Feed the tracing input so Caffe2 can be compared with the PyTorch output.
img_input = {_first_input_name(graph): x.data.numpy()}
c2_out = prepared_backend.run(img_input)[0]
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)

到这里,我们就成功地把 PyTorch 定义的模型和训练好的参数放到 Caffe2 框架下跑起来了。有没有觉得很6呢?

在cpp_caffe2下运行

C++ 端使用官方提供的 speed_benchmark.cc 例程。我们先生成这个 C++ 程序所需要的模型文件。

from caffe2.python.predictor import mobile_exporter

# Pull the Caffe2 workspace and net out of the prepared ONNX backend and split
# them into a parameter-filling init net and a predict net.
c2_workspace = prepared_backend.workspace
c2_graph = prepared_backend.predict_net
init_net, predict_net = mobile_exporter.Export(
    c2_workspace, c2_graph, c2_graph.external_input)

# Serialize both NetDefs; the C++ benchmark reads these exact filenames.
for out_path, net_proto in (('init_net.pd', init_net),
                            ('predict_net.pb', predict_net)):
    with open(out_path, 'wb') as f:
        f.write(net_proto.SerializeToString())

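可以看到,文件夹下面生成了 init_net.pd 和 predict_net.pb 两个文件(代码里参数文件的扩展名写成了 .pd,后面的运行命令也沿用了这个名字):第一个是模型的参数文件,第二个是模型的定义(网络结构)文件。为什么要把模型存成文件呢?因为这两个文件是序列化后的 Caffe2 NetDef protobuf,与语言和平台无关:python 和 cpp 代码都能使用它们,ubuntu 和 android 上也都能使用它们。(注意,能在 PyTorch 和 Caffe2 之间通用的是前面导出的 .onnx 文件,而不是这两个 Caffe2 文件。)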

# Run on caffe2_pythonfrom caffe2.proto import caffe2_pb2from caffe2.python import core, net_drawer, net_printer, visualize, workspace,utilsimport numpy as npimport osimport subprocessfrom PIL import Imagefrom skimage import io, transform
# Load the test image and split luma (Y) from chroma (Cb, Cr); only Y is fed
# to the network, the chroma planes are upsampled separately afterwards.
img = Image.open('./cat_244x244.jpg')
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

# Run the init net (fills the parameter blobs), then the predict net once.
for net in (init_net, predict_net):
    workspace.RunNetOnce(net)

# Human-readable dump of the net, to find the input and output blob names.
print(net_printer.to_string(predict_net))
# net: torch-jit-export11 = Conv(1, 2, kernels=[5L, 5L], strides=[1L, 1L], pads=[2L, 2L, 2L, 2L], dilations=[1L, 1L], group=1)12 = Add(11, 3, broadcast=1, axis=1)13 = Relu(12)15 = Conv(13, 4, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)16 = Add(15, 5, broadcast=1, axis=1)17 = Relu(16)19 = Conv(17, 6, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)20 = Add(19, 7, broadcast=1, axis=1)21 = Relu(20)23 = Conv(21, 8, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)24 = Add(23, 9, broadcast=1, axis=1)25, _onnx_dummy1 = Reshape(24, shape=[1L, 1L, 3L, 3L, 244L, 244L])26 = Transpose(25, axes=[0L, 1L, 4L, 2L, 5L, 3L])27, _onnx_dummy2 = Reshape(26, shape=[1L, 1L, 732L, 732L])
# Feed the luma plane as a float32 batch of one with a single channel:
# shape (1, 1, H, W).
luma_batch = np.array(img_y).astype(np.float32)[np.newaxis, np.newaxis, :, :]
workspace.FeedBlob('1', luma_batch)

# Forward pass, then fetch the final output blob '27'.
workspace.RunNetOnce(predict_net)
img_out = workspace.FetchBlob('27')
# Save the output as an image. Clip to [0, 255] *before* the uint8 cast:
# casting first makes out-of-range floats fail to saturate, and clipping
# uint8 data to [0, 255] afterwards is a no-op.
img_out_y = Image.fromarray(np.uint8(img_out[0, 0].clip(0, 255)), mode='L')
# Bicubically upsample the chroma planes to the output size and merge them
# with the super-resolved luma.
final_img = Image.merge(
    'YCbCr', [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert('RGB')
final_img.save('./cat_superres.jpg')
# Serialize the fed input blob '1' (the cat's luma plane) so the C++ benchmark
# can load it via --input_file.
serialized_input = workspace.SerializeBlob('1')
with open('input.blobproto', 'wb') as blob_file:
    blob_file.write(serialized_input)

编译cpp代码

我们需要编译cpp的代码,使用如下编译命令:

# Caffe2 headers and libraries are expected under /usr/local (sudo make
# install); only the bundled Eigen headers come from the source checkout.
CAFFE2_ROOT=$HOME/src/caffe2
g++ speed_benchmark.cc -o demo -std=c++11 \
    -I "$CAFFE2_ROOT/third_party/eigen" \
    -lCaffe2_CPU \
    -lglog \
    -lgflags \
    -lprotobuf \
    -lpthread \
    -llmdb \
    -lleveldb \
    -lopencv_core \
    -lopencv_highgui \
    -lopencv_imgproc

使用这条命令的前提是,Caffe2 已经通过 sudo make install 安装到了 /usr/local 下。这样它的头文件和库都在编译器的默认搜索路径里,命令中只需要额外指定 Eigen 的头文件路径。

运行cpp程序:

# One benchmark iteration on the serialized cat input; blob 27 is dumped
# into the current directory.
./demo \
    --init_net init_net.pd \
    --net predict_net.pb \
    --input 1 \
    --input_file input.blobproto \
    --output_folder . \
    --output 27 \
    --iter 1

等一下,我们先看一眼speed_benchmark.cc

#include <string>#include "caffe2/core/init.h"#include "caffe2/core/operator.h"#include "caffe2/proto/caffe2.pb.h"#include "caffe2/utils/proto_utils.h"#include "caffe2/utils/string_utils.h"#include "caffe2/core/logging.h"// 定义argsCAFFE2_DEFINE_string(net, "", "The given net to benchmark.");CAFFE2_DEFINE_string(init_net, "",                     "The given net to initialize any parameters.");CAFFE2_DEFINE_string(input, "",                     "Input that is needed for running the network. If "                     "multiple input needed, use comma separated string.");CAFFE2_DEFINE_string(input_file, "",                     "Input file that contain the serialized protobuf for "                     "the input blobs. If multiple input needed, use comma "                     "separated string. Must have the same number of items "                     "as input does.");CAFFE2_DEFINE_string(input_dims, "",                     "Alternate to input_files, if all inputs are simple "                     "float TensorCPUs, specify the dimension using comma "                     "separated numbers. If multiple input needed, use "                     "semicolon to separate the dimension of different "                     "tensors.");CAFFE2_DEFINE_string(output, "",                     "Output that should be dumped after the execution "                     "finishes. If multiple outputs are needed, use comma "                     "separated string. If you want to dump everything, pass "                     "'*' as the output value.");CAFFE2_DEFINE_string(output_folder, "",                     "The folder that the output should be written to. 
This "                     "folder must already exist in the file system.");CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run.");CAFFE2_DEFINE_bool(run_individual, false, "Whether to benchmark individual operators.");using std::string;using std::unique_ptr;using std::vector;int main(int argc, char** argv) {  caffe2::GlobalInit(&argc, &argv);  unique_ptr<caffe2::Workspace> workspace(new caffe2::Workspace());  // 读取模型参数到工作空间  caffe2::NetDef net_def;  CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &net_def));  CAFFE_ENFORCE(workspace->RunNetOnce(net_def));  // 加载输入数据,提供两种方式,--input和--input_dims  if (caffe2::FLAGS_input.size()) {    vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input);    if (caffe2::FLAGS_input_file.size()) {      vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file);      CAFFE_ENFORCE_EQ(          input_names.size(), input_files.size(),          "Input name and file should have the same number.");      for (int i = 0; i < input_names.size(); ++i) {        caffe2::BlobProto blob_proto;        CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));        workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);      }    } else if (caffe2::FLAGS_input_dims.size()) {      vector<string> input_dims_list = caffe2::split(';', caffe2::FLAGS_input_dims);      CAFFE_ENFORCE_EQ(          input_names.size(), input_dims_list.size(),          "Input name and dims should have the same number of items.");      for (int i = 0; i < input_names.size(); ++i) {        vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);        vector<int> input_dims;        for (const string& s : input_dims_str) {          input_dims.push_back(caffe2::stoi(s));        }        caffe2::TensorCPU* tensor =            workspace->GetBlob(input_names[i])->GetMutable<caffe2::TensorCPU>();        tensor->Resize(input_dims);       
 tensor->mutable_data<float>();      }    } else {      CAFFE_THROW("You requested input tensors, but neither input_file nor "                  "input_dims is set.");    }  }  // 加载模型定义文件,创建模型,  CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));  caffe2::NetBase* net = workspace->CreateNet(net_def);  CHECK_NOTNULL(net);  net->TEST_Benchmark(      caffe2::FLAGS_warmup,      caffe2::FLAGS_iter,      caffe2::FLAGS_run_individual);  // 获得输出  string output_prefix = caffe2::FLAGS_output_folder.size()      ? caffe2::FLAGS_output_folder + "/"      : "";  if (caffe2::FLAGS_output.size()) {    vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output);    if (caffe2::FLAGS_output == "*") {      output_names = workspace->Blobs();    }    for (const string& name : output_names) {      CAFFE_ENFORCE(          workspace->HasBlob(name),          "You requested a non-existing blob: ",          name);      string serialized = workspace->GetBlob(name)->Serialize(name);      string output_filename = output_prefix + name;      caffe2::WriteStringToFile(serialized, output_filename.c_str());    }  }  return 0;}

程序运行结束后,会在当前目录(--output_folder .)下生成一个名为 27 的文件,即输出 blob 27 的序列化结果。我们用 python 把这个文件转换为 jpg 图片。

# Turn the blob dumped by the C++ benchmark back into an image. With
# `--output_folder . --output 27` the program writes it to ./27
# (output_prefix + name in speed_benchmark.cc).
# NOTE(review): if the blob was instead pulled from the Android run under a
# different name (e.g. 27_mobile), adjust this path.
blob_proto = caffe2_pb2.BlobProto()
# Binary mode: ParseFromString expects bytes; `with` closes the file.
with open('./27', 'rb') as blob_file:
    blob_proto.ParseFromString(blob_file.read())
img_out = utils.Caffe2TensorToNumpyArray(blob_proto.tensor)
# Clip before the uint8 cast so out-of-range values saturate.
img_out_y = Image.fromarray(np.uint8((img_out[0, 0]).clip(0, 255)), mode='L')
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert('RGB')
final_img.save('./cat_superres_mobile.jpg')

关于放到android上运行

按照原教程的做法,可以顺利运行。这里写几点值得注意的点:
- android编译的可执行程序是静态编译,以便方便地在手机设备上运行。
- Android 程序并不一定要用 Java 写,这个例子就是用 C++ 写的;这也说明 android-cmake 这个项目真的很方便。
- 不过,如果仅仅在 Android 上运行控制台程序,那也太不爽了,那我还要这嵌入式设备做啥子呢?所以最后一定要放到一个有界面的 App 中运行,这部分就要用原生的 Java 来写了。
- 关于 Android 环境配置,最快的方式莫过于装一个 Android Studio。

原创粉丝点击