torch学习笔记3.1:实现自定义模块(lua)

来源:互联网 发布:阿里云vps购买 编辑:程序博客网 时间:2024/05/16 11:50

在使用torch时,如果想自己实现一个层,则可以按照《torch学习笔记1:实现自定义层》 中的方法来实现。但是如果想要实现一个比较复杂的网络,往往需要自己实现多个层(或类),并且有时可能需要重写其他模块中已有的函数来达到自己的目的,如果还是在nn模块中添加,会比较混乱,并且不利于本地git仓库统一管理,这个时候,我们可以自己实现一个像nn一样的模块,在代码中使用时 require即可。

我们来实现一个名为nxn的自定义模块,以及它的cuda版本cunxn模块,其中包含一个自定义的Hello类(lua实现),ReLU类(分别用CPU和GPU实现)。

由于篇幅原因,这里把torch自定义模块的lua实现,cpu实现,gpu实现分别写一篇文章,本文先介绍lua实现的Hello类。

1 总目录结构

模板源代码可在我的资源中下载。

.../myproj/      |----scripts/           |---- demo.lua      |----nxn/           |---- CMakeLists.txt           |---- nxn-scm-1.rockspec           |---- init.lua           |---- init.c           |---- ReLU.lua           |---- Hello.lua           |---- generic/                 |---- ReLU.c           |---- test/                 |---- test.lua      |----cunxn/           |---- CMakeLists.txt           |---- cunxn-scm-1.rockspec           |---- init.lua           |---- init.cu              |---- ReLU.cu           |---- test/                 |---- test.lua                

2 使用

  1. 成功安装了torch。
  2. 在nxn目录下运行
luarocks make nxn-scm-1.rockspec
  1. 在cunxn目录下运行
luarocks make cunxn-scm-1.rockspec
  1. 在scripts目录下运行
th demo.lua
  1. 输出
    result

3 文件说明

demo.lua

是使用自定义类的示例代码。

require 'cunxn'local module = nxn.Hello()module:updateOutput()input = torch.rand(3,3)print(input)local module = nxn.ReLU(false)output = module:updateOutput(input)print(output)cutorch.setDevice(2)input = input:cuda()print(input)local module = nxn.ReLU(true)output = module:updateOutput(input)print(output)

CMakeLists.txt

一般和nn之类的模块没有太大区别,仿照着写即可,需要注意的是以下几句:

......# 编译时从init.c找cpu实现的代码文件SET(src init.c) # 指定要编译的lua文件FILE(GLOB luasrc *.lua)SET(luasrc ${luasrc} test/test.lua)# 把cpp和lua文件加入模块nxnADD_TORCH_PACKAGE(nxn "${src}" "${luasrc}")# 链接lua库TARGET_LINK_LIBRARIES(nxn luaT TH)......

nxn-scm-1.rockspec

注意dependencies里面还可以添加已有模块,比如nn,cunn,格式如下:

......dependencies = {   "torch >= 7.0",   "cunn",   "nn"}......

init.lua

内容如下,要include自定义类的lua文件,以及这里把cpp实现编译成了一个lib,也要添加进来。

require('torch')require('libnxn')include('ReLU.lua')include('Hello.lua')

Hello.lua

自定义类的文件,该类由lua实现,这里提供一个简单的模板。

local Hello = torch.class('nxn.Hello')function Hello:__init()endfunction Hello:updateOutput()   print("hello in updateOutput")endfunction Hello:updateGradInput(input, gradOutput)   print("hello in updateGradInput")end

未完,后续说明见 CPU实现,GPU实现。

1 0
原创粉丝点击