pytorch model 2 coreml

来源:互联网 发布:手机淘宝 查看评价 编辑:程序博客网 时间:2024/06/04 23:23

需要将训练好的 pytorch model 移植到 ios上,需要转换成 coreml格式。
caffe2 可以在 coreml上使用,一种方式是 pytorch 转换成 caffe2, 再利用 caffe2的跨平台移植到移动端上;
此外,onnx 可以直接转换成 coreml 格式的。

需要转换的模型包含 conv, batchnorm, relu, avgpool2d, dropout, linear.
从源码编译 pytorch才能使用 onnx, 从源码编译才能使用 onnx.checker.编译完后先将pytorch model 转换成 onnx格式的,

import torchfrom model import TuneMobileNet, fcModelfrom torch.autograd import Variablemodel_name = 'mobile_dict.pt'state_dict = torch.load(model_name)cls_number = 17model = TuneMobileNet(cls_number)model.load_state_dict(state_dict)x = Variable(torch.randn(1, 3, 224, 224), requires_grad=True)torch_out = torch.onnx.export(model,                               x,                               'hard_mobile_convert.onnx',                              verbose=True,                               export_params=True)

在 avgpool2d的时候报错,padding size mismatch, 参考 issue [onnx] convered model the AveragePool bug in core-ml #3808, 和 pull requests fix pooling layer padding dim mistmatch bug #7。

解决了之后,再转换成 coreml格式的。

import onnximport onnx_coremlmodel = onnx.load('hard_mobile_convert.onnx')cml = onnx_coreml.convert(model)cml.save('hard_mobile.mlmodel')

接连报三个错误:
1. key error, 这是因为 pytorch 将 linear layer 转换成TransposeGemm两层,参考Why torch.nn.Linear is split into Transpose and Gemm layers in torch.onnx.export()? #3257, 和linear convert error #8, 就没有了transB参数。

if node.attrs["broadcast"] != 1 or node.attrs["transB"] != 1:        raise ValueError(            "Gemm is supported only for inner_product layer"        )key error: u'transB'

手动改为:

if node.attrs["broadcast"] != 1 or (hasattr(node, 'transB') and node.attrs["transB"] != 1):
  1. linear layer 只获取到了 bias, 没有 权重w
    在 onnx_coreml._graph中hard code 进去 w, 结果在onnx_coreml.convert中提示图出错,这是 coremltools自己的 check, debug 到了coremltools导入了libcoremlpython.so的动态库。

只能将整个模型切分成两个部分,featuresclassifier. features转换成 coreml识别的.mlmodel格式,classifier导出权重为 json文件,然后手动加载在 ios中。

原创粉丝点击