PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
来源:互联网 发布:电子相册制作软件知乎 编辑:程序博客网 时间:2024/05/16 09:06
步骤一 准备好你的C代码
首先,你写好你的C函数。
接下来你可以找到一个模块的forward和backward函数的实现,其主要实现输入相加的功能。
在你的.c
文件中,你可以使用#include <TH/TH.h>
和#include <THC/THC.h>
指令来分别包含TH及THC。
ffi工具可以确保编译器在build的过程中找到它们。
/* src/my_lib.c */#include <TH/TH.h>int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,THFloatTensor *output){ if (!THFloatTensor_isSameSizeAs(input1, input2)) return 0; THFloatTensor_resizeAs(output, input1); THFloatTensor_cadd(output, input1, 1.0, input2); return 1;}int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input){ THFloatTensor_resizeAs(grad_input, grad_output); THFloatTensor_fill(grad_input, 1); return 1;}
代码中没有约束条件。如果想要添加约束条件,你得准备一个header文件,它包含了所有希望在python中调用的函数的列表。
然后它会被ffi工具用来生成适当的封装。
/* src/my_lib.h */int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
现在,你需要一个简短的文件来生成(build)你的自定义扩展:
# build.pyfrom torch.utils.ffi import create_extensionffi = create_extension(name='_ext.my_lib',headers='src/my_lib.h',sources=['src/my_lib.c'],with_cuda=False)ffi.build()
步骤二 将它包含到你的Python代码中
当你运行完上述指令后,pytorch会创建一个_ext
目录,然后将你的my_lib
库放进去。
包名可以在最终的模块名前面有任意数量的包(数量也可以等于0).如果build成功,你可以像导入常规的python文件一样导入你的扩展。
定义新的函数:
# functions/add.pyimport torchfrom torch.autograd import Functionfrom _ext import my_libclass MyAddFunction(Function): def forward(self, input1, input2): output = torch.FloatTensor() my_lib.my_lib_add_forward(input1, input2, output) return output def backward(self, grad_output): grad_input = torch.FloatTensor() my_lib.my_lib_add_backward(grad_output, grad_input) return grad_input
定义新的模块:
# modules/add.pyfrom torch.nn import Modulefrom functions.add import MyAddFunctionclass MyAddModule(Module): def forward(self, input1, input2): return MyAddFunction()(input1, input2)
在模块中实现嵌套:
# main.pyimport torchimport torch.nn as nnfrom torch.autograd import Variablefrom modules.add import MyAddModuleclass MyNetwork(nn.Module): def __init__(self): super(MyNetwork, self).__init__() self.add = MyAddModule() def forward(self, input1, input2): return self.add(input1, input2)model = MyNetwork()input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))print(model(input1, input2))print(input1 + input2)
阅读全文
0 0
- PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
- PyTorch学习总结(五)——torch.nn
- 莫烦PyTorch学习笔记(一)——Torch或Numpy
- PyTorch的学习笔记-torch package
- 基于PyTorch的深度学习入门教程(二)——简单知识
- 基于PyTorch的深度学习入门教程(一)——PyTorch安装和配置
- 基于PyTorch的深度学习入门教程(七)——PyTorch重点综合实践
- pytorch学习笔记(十八):C 语言扩展 pytorch
- PyTorch学习—PyTorch是什么?
- pytorch 学习笔记之编写 C 扩展
- PyTorch基本用法(一)——Numpy,Torch对比
- PyTorch学习总结(三)——ONNX
- PyTorch学习总结(四)——Utilities
- PyTorch(二)——搭建和自定义网络
- PyTorch的concat也就是torch.cat实例
- Torch7学习(二) —— Torch与Matlab的语法对比
- 「Deep Learning」理解Pytorch中的「torch.utils.data」
- 基于PyTorch的深度学习入门教程(三)——自动梯度
- linux下mysql的root密码忘记解决方
- 在ubuntu中安装单机Hadoop(三)
- 坑爹的Android Ble问题记录日志
- 使用 phantomjs 异步爬取 ajax 网页数据
- Lintcode:尾部的零
- PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
- Linux下swoole环境搭建
- Linux如何查看当前占用CPU或内存最多的几个进程
- 关于WEBSERVICE的Connection reset异常
- 小学奥数思维训练题(十一)
- 目标检测标注工具labelImg使用方法
- spring-boot 学习之路
- JCVideoPlayerStandard的视频播放
- 常量指针必须初始化