Torch7入门续集补充--- nngraph包的使用
来源:互联网 发布:淘宝折800 编辑:程序博客网 时间:2024/05/17 07:35
构建方法
nngraph包在构建更加复杂的网络极其有用。毕竟是有点类似”静态图“了。
简单来说就是以前加网络需要不断add,现在用了nngraph,只要不断”一“就行了。
h1 = - nn.Linear(20,10)h2 = h1 - nn.Tanh() - nn.Linear(10,10) - nn.Tanh() - nn.Linear(10, 1)mlp = nn.gModule({h1}, {h2})
注意点:
1. 刚开始时需要用”-“来初始化。
2. 在nn.gModule中写入两个table,第一个table表示输入节点,第二个是输出节点。
当然,这两个table都可以有多个值。值得注意的是。这两个table必须是”node“。不能是任何其他的。
以Unet结构为例子:
function defineG_unet(input_nc, output_nc, ngf) local netG = nil -- input is (nc) x 256 x 256 -- 初始化时先用“-” local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) -- input is (ngf) x 128 x 128 local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) -- input is (ngf * 2) x 64 x 64 local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) -- input is (ngf * 4) x 32 x 32 local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 16 x 16 local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 8 x 8 local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 4 x 4 local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 2 x 2 local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 1 x 1 local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) -- input is (ngf * 8) x 2 x 2 local d1 = {d1_,e7} - nn.JoinTable(2) local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) -- input is (ngf * 8) x 4 x 4 local d2 = {d2_,e6} - nn.JoinTable(2) local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) -- input is (ngf * 8) x 8 x 8 local d3 = {d3_,e5} - nn.JoinTable(2) local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) -- input is (ngf * 8) x 16 x 16 local d4 = {d4_,e4} - nn.JoinTable(2) local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) -- input is (ngf * 4) x 32 x 32 local d5 = {d5_,e3} - nn.JoinTable(2) local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) -- input is (ngf * 2) x 64 x 64 local d6 = {d6_,e2} - nn.JoinTable(2) local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf) -- input is (ngf) x128 x 128 local d7 = {d7_,e1} - nn.JoinTable(2) local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) -- input is (nc) x 256 x 256 local o1 = d8 - nn.Tanh() -- 输入节点只有一个,e1 -- 输出节点也只有一个, o1 netG = nn.gModule({e1},{o1}) return netGend
简洁明了,另外,你如果用普通方式构建Unet,github上也有:
https://github.com/dmarnerides/dlt/blob/master/src/models/unet.lua
注意:
以前构建都是用 nn.Module来构建网络。而nngraph包中是用nngraph.Node来构建的网络的,构建出来的网络类型是nn.gModule,nn.gModule是继承自nn.Module的子类nn.Container。nn.Module可以通过上述提到的“-”来 “变成” nngraph.Node类型。
多个输入与多个输出
h1 = - nn.Linear(20,20)h2 = - nn.Linear(10,10)hh1 = h1 - nn.Tanh() - nn.Linear(20,1)hh2 = h2 - nn.Tanh() - nn.Linear(10,1)madd = {hh1,hh2} - nn.CAddTable()oA = madd - nn.Sigmoid()oB = madd - nn.Tanh()gmod = nn.gModule( {h1,h2}, {oA,oB} )
结构图:
输入时,
local out = gmod:forward({input1, input2})local out1 = out[1]local out2 = out[2]--或是 local unpack = unpack or table.unpacklocal out1,out2 = unpack(out)-- 反向传播gmod:backward({grad1, grad2})
其他
输入节点不能是自定义层
h1 = - nn.Linear(20,20)h2 = - nn.Linear(10,10)hh1 = h1 - nn.Tanh() - nn.Linear(20,1)hh2 = h2 - nn.Tanh() - nn.Linear(10,1)madd = {hh1,hh2} - nn.CAddTable()oA = madd - nn.Sigmoid()oB = madd - nn.Tanh()gmod = nn.gModule( {h1,h2}, {oA,oB} )
比如下面这样:
h1 = - nn.Linear(20,20) --这样是错误的,必须要用内置的nn.Module的层,这种自定义的层,重载了 -- nn.Module,会导致出错。h2 = - myLayerhh1 = h1 - nn.Tanh() - nn.Linear(20,1)hh2 = h2 - nn.Tanh() - nn.Linear(10,1)madd = {hh1,hh2} - nn.CAddTable()oA = madd - nn.Sigmoid()oB = madd - nn.Tanh()gmod = nn.gModule( {h1,h2}, {oA,oB} )
Expected nnop.Parameters node, found : nn.MyLayer
当然,中间节点可以是自定义层。
nn.gModule的基本知识
gModule可以允许多个输入,多个输出,当然构建gModule的modules不能构成环。
每个结点可以有多个输入,这些结点的输入顺序存储在node.data.mapindex
中,即每个结点的父结点的指针。
每个结点的输入只能是一个,当然我们可以用类似nn.JoinTable(2)
的方法将其join后,输入到一个结点中。如果输出可以是一张表,存着要输出的各个结点。
另外,node.data.input
是一个list,存储着所有输入,如果输入只有一个,那么只有node.data.input[1]
使用了。
值得注意的是, node.data.gradOutput
是将所有的连接到该结点的反传梯度全部相加,再往前传。
还有一点就是,网络的第一个结点和最后一个结点都是“dummy”的。因为输入和输出可能是多个,所以要用这些dummy的结点来分别处理多个输入和输出。比如对于多个输出,则最后一个dummy结点,内部通过split操作,可以将网络分块,每个结点对应这些“网络块”,当然这些小块很可能会有重叠区域。有趣的是,对于多个输入和多个输出,我们反传时,输入可以不要求全部填满,比如只填第一个输入,那么整个网络相当于截取只和第一个输入有关的网络进行更新。
gModule获得某结点的信息
每个结点包含一个module。如果只是单独想看module的信息,那么直接net.modules[i]
就行。而net.fowardnodes
可以获得每个结点的信息。
只要获得了node
,我们主要可以获得结点的输入以及gradOutput.
node.data.inputnode.data.gradOutput
你可能会问,那gradInput呢?结点是没有gradInput的。每个结点首先综合gradOutput,从而得到本结点的gradOutput, 再将gradOuput传入结点内部的module,得到gradInput之后,再将gradInput作为该结点子节点的gradOuput.
由于每个结点存储了一个module node.data.module
以得到存储某个结点的module的各种信息。
--这是获得net.modules[9]的。--这种写法的好处就是,可以获得结点的信息,而不是单独结点内部的module的信息。local ind = 10local latent = nil for indexNode, node in pairs(net.forwardnodes) do if indexNode == ind then if node.data.module then latent = node.data.module.output:clone() -- use it to get the specific module output end end end
简略看看gModule
gModule的初始化
大概步骤:
首先进行检查输入和输出。要求必须是nngraph.Node
类型的。就是上面提到的-
可以把nn.Module变成nngraph.Node类型。
然后对输入再次进行检查,如果输入只有一个,直接inputs[1]:add(innode,true)
,
否则多个输入必须检查每个输入结点不能有子结点。
然后再构建2张图:fg和bg图。fg图用于网络前向传播评估用,bg用于反向传播。
如果输入结点有多个,还要对每个结点进行assert
assert(root.data.module, 'Expected nnop.Parameters node, module not found in node') assert(torch.typename(root.data.module) == 'nnop.Parameters', 'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
这也是为什么在刚才的例子中,不能输入结点不能是自定义的层。从代码上简要的看,如果输入结点只有一个的话,应该可以。
最后把每个结点的其他一些信息,加入到gModule中。
function gModule:__init(inputs,outputs) parent.__init(self) -- the graph is defined backwards, we have the output modules as input here -- we will define a dummy output node that connects all output modules -- into itself. This will be the output for the forward graph and -- input point for the backward graph local node local outnode = nngraph.Node({input={}}) for i = 1, utils.tableMaxN(outputs) do node = outputs[i] if torch.typename(node) ~= 'nngraph.Node' then error(utils.expectingNodeErrorMessage(node, 'outputs', i)) end outnode:add(node, true) end for i = 1, utils.tableMaxN(inputs) do node = inputs[i] if torch.typename(node) ~= 'nngraph.Node' then error(utils.expectingNodeErrorMessage(node, 'inputs', i)) end end -- We add also a dummy input node. -- The input node will be split to feed the passed input nodes. local innode = nngraph.Node({input={}}) assert(#inputs > 0, "no inputs are not supported") if #inputs == 1 then inputs[1]:add(innode,true) else local splits = {innode:split(#inputs)} for i = 1, #inputs do assert(#inputs[i].children == 0, "an input should have no inputs") end for i = 1, #inputs do inputs[i]:add(splits[i],true) end end -- the backward graph (bg) is for gradients -- the forward graph (fg) is for function evaluation self.bg = outnode:graph() self.fg = self.bg:reverse() -- the complete graph is constructed -- now regenerate the graphs with the additional nodes local roots = self.fg:roots() -- if there are more than one root in the forward graph, then make sure that -- extra roots are parameter nodes if #roots > 1 then local innodeRoot = nil -- first find our innode for _, root in ipairs(roots) do if root.data == innode.data then assert(innodeRoot == nil, 'more than one matching input node found in leaves') innodeRoot = root else assert(root.data.module, 'Expected nnop.Parameters node, module not found in node') assert(torch.typename(root.data.module) == 'nnop.Parameters', 'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module)) end end assert(innodeRoot ~= nil, 'input node not found among roots') self.innode = innodeRoot else assert(#self.fg:roots() == 1, "expecting only one start") self.innode = self.fg:roots()[1] end assert(self.innode.data == innode.data, "expecting the forward innode") self.outnode = outnode self.verbose = false self.nInputs = #inputs -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() -- iteratare over all nodes: check, tag and add to container for i,node in ipairs(self.forwardnodes) do -- check for unused inputs or unused split() outputs if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then local nUnused = node.data.nSplitOutputs - #node.children local debugLabel = node.data.annotations._debugLabel local errStr = "%s of split(%s) outputs from the node declared at %s are unused" error(string.format(errStr, nUnused, node.data.nSplitOutputs, debugLabel)) end -- Check whether any nodes were defined as taking this node as an input, -- but then left dangling and don't connect to the output. If this is -- the case, then they won't be present in forwardnodes, so error out. for successor, _ in pairs(node.data.reverseMap) do local successorIsInGraph = false -- Only need to the part of forwardnodes from i onwards, topological -- sort guarantees it cannot be in the first part. for j = i+1, #self.forwardnodes do -- Compare equality of data tables, as new Node objects have been -- created by processes such as topoological sort, but the -- underlying .data table is shared. if self.forwardnodes[j].data == successor.data then successorIsInGraph = true break end end local errStr = "node declared on %s does not connect to gmodule output" assert(successorIsInGraph, string.format(errStr, successor.data.annotations._debugLabel)) end -- set data.forwardNodeId for node:label() output node.data.forwardNodeId = node.id -- add module to container if node.data.module then self:add(node.data.module) end end self.output = nil self.gradInput = nil if #self.outnode.children > 1 then self.output = self.outnode.data.input endend
gModule的反向传播
主要调用这个函数,可以看到
local function getTotalGradOutput(node) local gradOutput = node.data.gradOutput assert(istable(gradOutput), "expecting gradients to sum") if #gradOutput > 1 then -- Check if we can bypass the allocation, for the special case where all -- gradOutputs but one are zero tensors with an underlying one-element -- storage. Note that for the case that we -- cannot bypass it, this check will only be performed once if not node.data.gradOutputBuffer then local count = 0 local idx = 1 -- Count how many gradOutput are tensors of 1 element filled with zero for i=1,#gradOutput do local zero = torch.isTensor(gradOutput[i]) and gradOutput[i]:storage() ~= nil and gradOutput[i]:storage():size() == 1 and gradOutput[i]:storage()[1] == 0 if not zero then idx = i count = count + 1 end end if count < 2 then -- Return the only non-zero one, or the first one -- if they are all zero return gradOutput[idx] end end node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) local gobuff = node.data.gradOutputBuffer nesting.resizeNestedAs(gobuff, gradOutput[1]) nesting.copyNested(gobuff, gradOutput[1]) -- 注释: for i=2,#gradOutput do nesting.addNestedTo(gobuff, gradOutput[i]) end gradOutput = gobuff else gradOutput = gradOutput[1] end return gradOutputend
注释:可以看到这里首先将第一个节点的第一个梯度node.data.gradOutput[1]
作为gobuff
, 然后对于其他的梯度进行addNestedTo
操作。
-- Adds the input to the output.-- The input can contain nested tables.-- The output will contain the same nesting of tables.function nesting.addNestedTo(output, input) if torch.isTensor(output) then output:add(input) --不断累加 else for key, child in pairs(input) do assert(output[key] ~= nil, "missing key") nesting.addNestedTo(output[key], child) end endend
可以看到是不断累加的。得到总的gradOuput后,再传入Node里面的module内,调用
module的updateGradInput来计算。
反向传播的主要代码:
function gModule:updateGradInput(input,gradOutput) local function neteval(node) if node.data.selectindex then assert(not node.data.module, "the selectindex-handling nodes should have no module") assert(#node.children == 1, "only the splitted node should be the input") local child = node.children[1] local go = getTotalGradOutput(node) child.data.gradOutput = child.data.gradOutput or {} assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") -- The data.gradOutput holds the to-be-summed gradients. child.data.gradOutput[1] = child.data.gradOutput[1] or {} assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet") child.data.gradOutput[1][node.data.selectindex] = go else --********得到结点的总的gradOutput************** local gradOutput = getTotalGradOutput(node) -- updateGradInput through this node -- If no module is present, the node behaves like nn.Identity. local gradInput if not node.data.module then gradInput = gradOutput else local input = node.data.input -- a parameter node is captured if input == nil and node.data.module ~= nil then input = {} end if #input == 1 then input = input[1] end --********得到结点存储的module,调用module的updateGradInput --********进行更新。 local module = node.data.module gradInput = module:updateGradInput(input,gradOutput) end -- 反传时,将该结点的梯度传给每个子结点的data.gradOutput -- propagate the output to children for i,child in ipairs(node.children) do child.data.gradOutput = child.data.gradOutput or {} local mapindex = node.data.mapindex[child.data] local gi if #node.children == 1 then gi = gradInput else gi = gradInput[mapindex] end table.insert(child.data.gradOutput,gi) end end if self.verbose then print(' V : ' .. node:label()) end end local outnode = self.outnode if #outnode.children > 1 and #gradOutput ~= #outnode.children then error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) end for _,node in ipairs(self.backwardnodes) do local gradOutput = node.data.gradOutput while gradOutput and #gradOutput >0 do table.remove(gradOutput) end end -- Set the starting gradOutput. outnode.data.gradOutput = outnode.data.gradOutput or {} outnode.data.gradOutput[1] = gradOutput for i,node in ipairs(self.backwardnodes) do neteval(node) end assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once") self.gradInput = self.innode.data.gradOutput[1] return self.gradInputend
可以看到,反传时,就是先对每个Node综合一下梯度,然后将该梯度传给该结点的每个子节点。然后对每个结点这样做,最终的梯度就是self.innode.data.gradOutput[1]
- Torch7入门续集补充--- nngraph包的使用
- Torch7入门续集(二)---- 更好的使用Math函数
- Torch7入门续集补充(2)--- 每一层设置不同的学习率(finetuning有用)
- Torch7入门续集(三)----Simple Layers的妙用
- Torch7入门续集(四)----利用Table layers构建灵活的网络
- Torch7入门续集(一)----- 更加深入理解Tensor
- Torch7入门续集(六)----多GPU运行程序
- Torch7入门续集(七)--- clone与net替换某一层
- Torch7入门续集(八)---终结篇----不再写Torch博客了,反正就是难受
- nngraph的问题解决
- torch7入门
- Torch7的使用之基本知识
- 使用Torch nngraph实现LSTM
- 使用Torch nngraph实现LSTM
- Spring MVC 入门续集
- torch7的安装
- [Torch7]的安装
- [Torch7]的安装
- jetty post 提交的数据太大
- POJ 1639 Picnic Planning
- 每天一个shell命令(更新中)
- Python 获取指定目录下级文件
- linux下安装nodejs
- Torch7入门续集补充--- nngraph包的使用
- 数据结构之红黑树(一)——基础分析
- float与double的范围和精度
- jquery中获取元素的几种方式小结
- 函数参数中‘*’的含义
- 数据库笔试面试42——SQL标准中用来调用存储过程的
- python 产生随机数,随机字符串
- 总结
- java 大段代码的变量命名问题