torch学习笔记3.2:实现自定义模块(cpu)

来源:互联网 发布:findit软件序列号 编辑:程序博客网 时间:2024/06/05 20:14

在使用torch时,如果想自己实现一个层,则可以按照《torch学习笔记1:实现自定义层》 中的方法来实现。但是如果想要实现一个比较复杂的网络,往往需要自己实现多个层(或类),并且有时可能需要重写其他模块中已有的函数来达到自己的目的,如果还是在nn模块中添加,会比较混乱,并且不利于本地git仓库统一管理,这个时候,我们可以自己实现一个像nn一样的模块,在代码中使用时 require即可。

我们来实现一个名为nxn的自定义模块,以及它的cuda版本cunxn模块,其中包含一个自定义的Hello类(lua实现),ReLU类(分别用CPU和GPU实现)。

由于篇幅原因,这里把torch自定义模块的lua实现,cpu实现,gpu实现分别写一篇文章,本文介绍cpu实现的ReLU类。

1 总目录结构和 2 使用说明 在 《torch学习笔记3.1:实现自定义模块(lua)》

3 文件说明

ReLU.lua

local ReLU = torch.class('nxn.ReLU')function ReLU:__init(gpucompatible)   self.gpucompatible=gpucompatible   if self.gpucompatible then      self.gradInput=torch.CudaTensor()      self.output=torch.CudaTensor()   else      self.gradInput=torch.Tensor()      self.output=torch.Tensor()   end   self.outputSave=self.output   self.gradInputSave=self.gradInputendfunction ReLU:updateOutput(input)   -- 调用cpp实现的ReLU函数   return input.nxn.ReLU_updateOutput(self, input)endfunction ReLU:updateGradInput(input, gradOutput)   -- 调用cpp实现的ReLU函数   return input.nxn.ReLU_updateGradInput(self, input, gradOutput)endfunction ReLU:getDisposableTensors()   return {self.output, self.gradInput, self.gradInputSave, self.outputSave}end

ReLU.c

内容如下:

#ifndef TH_GENERIC_FILE#define TH_GENERIC_FILE "generic/ReLU.c"#elsestatic int nxn_(ReLU_updateOutput)(lua_State *L){  printf("CPU version of ReLU updateOutput function\n");  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);  THTensor_(resizeAs)(output, input);  TH_TENSOR_APPLY2(real, output, real, input,         \                   *output_data = *input_data > 0 ? *input_data : 0;)  return 1;}static int nxn_(ReLU_updateGradInput)(lua_State *L){  printf("CPU version of ReLU updateGradInput function\n");  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);  THTensor_(resizeAs)(gradInput, output);  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,     \                   *gradInput_data = *gradOutput_data * (*output_data > 0 ? 1 : 0););  return 1;}static const struct luaL_Reg nxn_(ReLU__) [] = {  {"ReLU_updateOutput", nxn_(ReLU_updateOutput)},  {"ReLU_updateGradInput", nxn_(ReLU_updateGradInput)},  {NULL, NULL}};static void nxn_(ReLU_init)(lua_State *L){  luaT_pushmetatable(L, torch_Tensor);  luaT_registeratname(L, nxn_(ReLU__), "nxn");  lua_pop(L,1);}#endif

init.c

在编译安装模块时CMakeLists.txt根据init.c找类文件:

#include "TH.h"#include "luaT.h"#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor)#define nxn_(NAME) TH_CONCAT_3(nxn_, Real, NAME)#include "generic/ReLU.c"#include "THGenerateFloatTypes.h"LUA_EXTERNC DLL_EXPORT int luaopen_libnxn(lua_State *L);// 把cpp实现编译到libnxnint luaopen_libnxn(lua_State *L){  lua_newtable(L);  lua_pushvalue(L, -1);  lua_setfield(L, LUA_GLOBALSINDEX, "nxn");  nxn_FloatReLU_init(L);  nxn_DoubleReLU_init(L);  return 1;}
0 0
原创粉丝点击