Torch7 Tutorial Series Supplement: Using the nngraph Package


Building networks

The nngraph package is extremely useful for building more complex networks; after all, it behaves a bit like a "static graph".
In short, where you previously added layers with repeated add calls, with nngraph you simply keep chaining nodes with "-".

h1 = - nn.Linear(20,10)
h2 = h1
     - nn.Tanh()
     - nn.Linear(10,10)
     - nn.Tanh()
     - nn.Linear(10, 1)
mlp = nn.gModule({h1}, {h2})
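As a quick sanity check, the resulting mlp can be used like any other nn.Module (a minimal sketch; the tensor sizes simply follow the Linear layers above):

require 'nngraph'
local x = torch.randn(20)          -- a 20-dim input, matching nn.Linear(20,10)
local y = mlp:forward(x)           -- a 1-element output tensor
mlp:zeroGradParameters()
mlp:backward(x, torch.ones(1))     -- backward pass with a dummy gradOutput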

Points to note:
1. At the very start, a module has to be turned into a node with a leading "-".
2. nn.gModule takes two tables: the first holds the input nodes, the second the output nodes.
Both tables may of course contain several entries. Importantly, every entry must be a "node"; nothing else is accepted.
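Incidentally, nngraph also offers a call syntax in which calling a module instance produces a node; the graph above can be written equivalently like this (a sketch of the same three-layer mlp):

require 'nngraph'
local h1 = nn.Linear(20, 10)()     -- calling the module creates an nngraph.Node
local h2 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(h1))))
local mlp = nn.gModule({h1}, {h2})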

Take the U-Net architecture as an example:

function defineG_unet(input_nc, output_nc, ngf)
    local netG = nil
    -- input is (nc) x 256 x 256
    -- use "-" first to create the input node
    local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
    -- input is (ngf) x 128 x 128
    local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 8 x 8
    local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 4 x 4
    local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 2 x 2
    local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 1 x 1
    local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 2 x 2
    local d1 = {d1_,e7} - nn.JoinTable(2)
    local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 4 x 4
    local d2 = {d2_,e6} - nn.JoinTable(2)
    local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 8 x 8
    local d3 = {d3_,e5} - nn.JoinTable(2)
    local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local d4 = {d4_,e4} - nn.JoinTable(2)
    local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local d5 = {d5_,e3} - nn.JoinTable(2)
    local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local d6 = {d6_,e2} - nn.JoinTable(2)
    local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
    -- input is (ngf) x 128 x 128
    local d7 = {d7_,e1} - nn.JoinTable(2)
    local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
    -- input is (nc) x 256 x 256
    local o1 = d8 - nn.Tanh()
    -- there is only one input node, e1
    -- and only one output node, o1
    netG = nn.gModule({e1},{o1})
    return netG
end
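A minimal usage sketch (assuming 3-channel images and ngf = 64; note that the skip connections use nn.JoinTable(2), i.e. they join along the channel axis of a batched 4D tensor, so the input needs a leading batch dimension):

local netG = defineG_unet(3, 3, 64)
local x = torch.randn(1, 3, 256, 256)   -- a batch holding one 3 x 256 x 256 image
local y = netG:forward(x)
print(y:size())                         -- expected: 1 x 3 x 256 x 256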

Short and clear. If you would rather build a U-Net in the ordinary container style, there is also an example on GitHub:
https://github.com/dmarnerides/dlt/blob/master/src/models/unet.lua
Note:
Networks used to be built directly from nn.Module objects. With the nngraph package the network is built from nngraph.Node objects, and the resulting network has type nn.gModule, which inherits from nn.Container, itself a subclass of nn.Module. An nn.Module is turned into an nngraph.Node by the "-" operator mentioned above.
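These types can be checked directly (a small sketch, reusing h1 and mlp from the first example):

print(torch.type(h1))                       -- nngraph.Node
print(torch.type(mlp))                      -- nn.gModule
print(torch.isTypeOf(mlp, 'nn.Container'))  -- true
print(torch.isTypeOf(mlp, 'nn.Module'))     -- true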

Multiple inputs and multiple outputs

h1 = - nn.Linear(20,20)
h2 = - nn.Linear(10,10)
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )

The structure diagram is omitted here (image from the original post). To run the network and read its two outputs:

local out = gmod:forward({input1, input2})
local out1 = out[1]
local out2 = out[2]
-- or equivalently:
local unpack = unpack or table.unpack
local out1, out2 = unpack(out)
-- backward pass: note that backward takes the inputs as well as the gradients
gmod:backward({input1, input2}, {grad1, grad2})
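For concreteness, a self-contained run with dummy tensors (the sizes follow the Linear layers above; grad1 and grad2 are dummy gradients for the two scalar outputs):

local input1, input2 = torch.randn(20), torch.randn(10)
local grad1, grad2 = torch.ones(1), torch.ones(1)
local out = gmod:forward({input1, input2})
gmod:backward({input1, input2}, {grad1, grad2})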

Other notes

Input nodes cannot be custom layers


For example, the following:

h1 = - nn.Linear(20,20)
-- This is wrong: input nodes must be built from the built-in nn.Module layers.
-- A custom layer like this, even though it subclasses nn.Module, causes an error here.
h2 = - myLayer
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )

This fails with the error:
Expected nnop.Parameters node, found : nn.MyLayer
Intermediate nodes, on the other hand, can be custom layers.
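A common workaround (a minimal sketch; myLayer is the hypothetical custom nn.Module from above) is to make the input node a built-in nn.Identity() and push the custom layer one step in, so that it becomes an intermediate node:

h1 = - nn.Linear(20,20)
h2 = - nn.Identity()                               -- built-in module as the input node
hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
hh2 = h2 - myLayer - nn.Tanh() - nn.Linear(10,1)   -- custom layer is now an intermediate node
madd = {hh1,hh2} - nn.CAddTable()
oA = madd - nn.Sigmoid()
oB = madd - nn.Tanh()
gmod = nn.gModule( {h1,h2}, {oA,oB} )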

Basics of nn.gModule

A gModule allows multiple inputs and multiple outputs; of course, the modules that make up the gModule must not form a cycle.
A node can have several inputs; their order is stored in node.data.mapindex, which holds references to the node's parent nodes.

The module inside a node still sees a single input value, so when several tensors have to feed one node we usually join them first, e.g. with nn.JoinTable(2), and pass the joined result in (or use a table-aware module such as nn.CAddTable). The output of the gModule can be a table holding all of the output nodes.

In addition, node.data.input is a list holding all of the node's inputs; if there is only one input, only node.data.input[1] is used.
Note that node.data.gradOutput collects all gradients flowing back into the node; they are summed before being propagated further back.

One more point: the first and last nodes of the network are "dummy" nodes. Because there may be several inputs and several outputs, these dummy nodes are what handle them. For multiple outputs, for instance, the final dummy node internally performs a split, which effectively divides the network into blocks, one per output node; these blocks may well overlap. Interestingly, with multiple inputs and outputs, back-propagation does not require every gradient slot to be filled with a meaningful value: if, say, only the first one is supplied, effectively only the sub-network connected to it gets updated.

Getting node information from a gModule

Each node holds one module. If you only want to inspect a module on its own, net.modules[i] is enough; net.forwardnodes, on the other hand, gives access to every node.
Once you have a node, the main things you can read are its input and its gradOutput:

node.data.input
node.data.gradOutput

You might ask: what about gradInput? A node itself has no gradInput. Each node first sums its incoming gradients to obtain its own gradOutput, passes that gradOutput into the module stored in the node to compute a gradInput, and then hands this gradInput to the node's children as their gradOutput.
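If you do want the gradient flowing out of a particular node, you can read it from the module stored inside the node (a small sketch; nn modules cache the gradInput computed by their last backward pass):

-- after gmod:backward(...), for some node of interest:
local g = node.data.module.gradInput   -- gradient w.r.t. that module's input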

Since each node stores a module, node.data.module can be used to get at all the information of the module held by that node.

-- This fetches what corresponds to net.modules[9].
-- The advantage of writing it this way is that you get the node's information,
-- not just the information of the module inside the node.
local ind = 10
local latent = nil
for indexNode, node in pairs(net.forwardnodes) do
    if indexNode == ind then
        if node.data.module then
            latent = node.data.module.output:clone()  -- use it to get the specific module output
        end
    end
end
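A slightly more robust variant (a sketch, assuming the node of interest was labelled with nngraph's annotate when the graph was built) looks the node up by name instead of by a hard-coded index:

-- when building the graph, label the node, e.g.:
--   local e8 = (e7 - nn.LeakyReLU(0.2, true) - ...):annotate{name = 'bottleneck'}
local latent = nil
for _, node in ipairs(net.forwardnodes) do
   if node.data.annotations.name == 'bottleneck' and node.data.module then
      latent = node.data.module.output:clone()
   end
end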

A quick look at the gModule source

gModule initialization

The rough steps:
First the inputs and outputs are checked: they must be of type nngraph.Node. As mentioned above, "-" turns an nn.Module into an nngraph.Node.
The inputs are then checked again. If there is exactly one input, it is connected directly with inputs[1]:add(innode,true);
otherwise, with several inputs, each input node is checked to have no children, and the dummy input node is split across them.
Two graphs are then built: fg and bg. fg is used for forward evaluation of the network, bg for back-propagation.
If the forward graph ends up with more than one root, every root other than the dummy input node is asserted to be an nnop.Parameters node:

            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))

This is also why, in the example earlier, the input nodes could not be custom layers. From a quick reading of the code, a single input node should be fine, since the assert only runs when the forward graph has more than one root.
Finally, some additional per-node information is added to the gModule.

function gModule:__init(inputs,outputs)
   parent.__init(self)
   -- the graph is defined backwards, we have the output modules as input here
   -- we will define a dummy output node that connects all output modules
   -- into itself. This will be the output for the forward graph and
   -- input point for the backward graph
   local node
   local outnode = nngraph.Node({input={}})
   for i = 1, utils.tableMaxN(outputs) do
      node = outputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'outputs', i))
      end
      outnode:add(node, true)
   end
   for i = 1, utils.tableMaxN(inputs) do
      node = inputs[i]
      if torch.typename(node) ~= 'nngraph.Node' then
         error(utils.expectingNodeErrorMessage(node, 'inputs', i))
      end
   end
   -- We add also a dummy input node.
   -- The input node will be split to feed the passed input nodes.
   local innode = nngraph.Node({input={}})
   assert(#inputs > 0, "no inputs are not supported")
   if #inputs == 1 then
      inputs[1]:add(innode,true)
   else
      local splits = {innode:split(#inputs)}
      for i = 1, #inputs do
         assert(#inputs[i].children == 0, "an input should have no inputs")
      end
      for i = 1, #inputs do
         inputs[i]:add(splits[i],true)
      end
   end
   -- the backward graph (bg) is for gradients
   -- the forward graph (fg) is for function evaluation
   self.bg = outnode:graph()
   self.fg = self.bg:reverse()
   -- the complete graph is constructed
   -- now regenerate the graphs with the additional nodes
   local roots = self.fg:roots()
   -- if there are more than one root in the forward graph, then make sure that
   -- extra roots are parameter nodes
   if #roots > 1 then
      local innodeRoot = nil
      -- first find our innode
      for _, root in ipairs(roots) do
         if root.data == innode.data then
            assert(innodeRoot == nil, 'more than one matching input node found in leaves')
            innodeRoot = root
         else
            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
            assert(torch.typename(root.data.module) == 'nnop.Parameters',
                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
         end
      end
      assert(innodeRoot ~= nil, 'input node not found among roots')
      self.innode = innodeRoot
   else
      assert(#self.fg:roots() == 1, "expecting only one start")
      self.innode = self.fg:roots()[1]
   end
   assert(self.innode.data == innode.data, "expecting the forward innode")
   self.outnode = outnode
   self.verbose = false
   self.nInputs = #inputs
   -- computation on the graph is done through topsort of forward and backward graphs
   self.forwardnodes = self.fg:topsort()
   self.backwardnodes = self.bg:topsort()
   -- iteratare over all nodes: check, tag and add to container
   for i,node in ipairs(self.forwardnodes) do
      -- check for unused inputs or unused split() outputs
      if node.data.nSplitOutputs and node.data.nSplitOutputs ~=  #node.children then
         local nUnused = node.data.nSplitOutputs - #node.children
         local debugLabel = node.data.annotations._debugLabel
         local errStr =
            "%s of split(%s) outputs from the node declared at %s are unused"
         error(string.format(errStr, nUnused, node.data.nSplitOutputs,
                             debugLabel))
      end
      -- Check whether any nodes were defined as taking this node as an input,
      -- but then left dangling and don't connect to the output. If this is
      -- the case, then they won't be present in forwardnodes, so error out.
      for successor, _ in pairs(node.data.reverseMap) do
         local successorIsInGraph = false
         -- Only need to the part of forwardnodes from i onwards, topological
         -- sort guarantees it cannot be in the first part.
         for j = i+1, #self.forwardnodes do
            -- Compare equality of data tables, as new Node objects have been
            -- created by processes such as topoological sort, but the
            -- underlying .data table is shared.
            if self.forwardnodes[j].data == successor.data then
               successorIsInGraph = true
               break
            end
         end
         local errStr =
            "node declared on %s does not connect to gmodule output"
         assert(successorIsInGraph,
                string.format(errStr, successor.data.annotations._debugLabel))
      end
      -- set data.forwardNodeId for node:label() output
      node.data.forwardNodeId = node.id
      -- add module to container
      if node.data.module then
         self:add(node.data.module)
      end
   end
   self.output = nil
   self.gradInput = nil
   if #self.outnode.children > 1 then
      self.output = self.outnode.data.input
   end
end

gModule back-propagation

Back-propagation mainly relies on the following helper; as you can see:

local function getTotalGradOutput(node)
   local gradOutput = node.data.gradOutput
   assert(istable(gradOutput), "expecting gradients to sum")
   if #gradOutput > 1 then
      -- Check if we can bypass the allocation, for the special case where all
      -- gradOutputs but one are zero tensors with an underlying one-element
      -- storage. Note that for the case that we
      -- cannot bypass it, this check will only be performed once
      if not node.data.gradOutputBuffer then
         local count = 0
         local idx = 1
         -- Count how many gradOutput are tensors of 1 element filled with zero
         for i=1,#gradOutput do
            local zero = torch.isTensor(gradOutput[i]) and
                         gradOutput[i]:storage() ~= nil and
                         gradOutput[i]:storage():size() == 1 and
                         gradOutput[i]:storage()[1] == 0
            if not zero then
               idx = i
               count = count + 1
            end
         end
         if count < 2 then
            -- Return the only non-zero one, or the first one
            -- if they are all zero
            return gradOutput[idx]
         end
      end
      node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
      local gobuff = node.data.gradOutputBuffer
      nesting.resizeNestedAs(gobuff, gradOutput[1])
      nesting.copyNested(gobuff, gradOutput[1])
      -- note: accumulate the remaining gradients into the buffer (see the note below)
      for i=2,#gradOutput do
         nesting.addNestedTo(gobuff, gradOutput[i])
      end
      gradOutput = gobuff
   else
      gradOutput = gradOutput[1]
   end
   return gradOutput
end

Note: you can see that the node's first gradient, node.data.gradOutput[1], is first copied into gobuff, and the remaining gradients are then accumulated into it with addNestedTo.

-- Adds the input to the output.
-- The input can contain nested tables.
-- The output will contain the same nesting of tables.
function nesting.addNestedTo(output, input)
   if torch.isTensor(output) then
      output:add(input) -- keep accumulating
   else
      for key, child in pairs(input) do
         assert(output[key] ~= nil, "missing key")
         nesting.addNestedTo(output[key], child)
      end
   end
end

So the gradients are simply accumulated. Once the total gradOutput has been obtained, it is passed into the module inside the node, and the module's updateGradInput is called to do the actual computation.
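To see the summation in practice, a small sketch using the two-output gmod from earlier: madd feeds both oA and oB, so during backward its node receives two gradients, which getTotalGradOutput adds together:

local x1, x2 = torch.randn(20), torch.randn(10)
gmod:forward({x1, x2})
gmod:backward({x1, x2}, {torch.ones(1), torch.ones(1)})
for _, node in ipairs(gmod.backwardnodes) do
   if node.data.module and torch.type(node.data.module) == 'nn.CAddTable' then
      print(#node.data.gradOutput)   -- expected: 2 (one gradient from each output branch)
   end
end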

The main back-propagation code:

function gModule:updateGradInput(input,gradOutput)
   local function neteval(node)
      if node.data.selectindex then
         assert(not node.data.module, "the selectindex-handling nodes should have no module")
         assert(#node.children == 1, "only the splitted node should be the input")
         local child = node.children[1]
         local go = getTotalGradOutput(node)
         child.data.gradOutput = child.data.gradOutput or {}
         assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
         -- The data.gradOutput holds the to-be-summed gradients.
         child.data.gradOutput[1] = child.data.gradOutput[1] or {}
         assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
         child.data.gradOutput[1][node.data.selectindex] = go
      else
         -- ******** get the node's total gradOutput ********
         local gradOutput = getTotalGradOutput(node)
         -- updateGradInput through this node
         -- If no module is present, the node behaves like nn.Identity.
         local gradInput
         if not node.data.module then
            gradInput = gradOutput
         else
            local input = node.data.input
            -- a parameter node is captured
            if input == nil and node.data.module ~= nil then
               input = {}
            end
            if #input == 1 then
               input = input[1]
            end
            -- ******** take the module stored in the node and call its
            -- ******** updateGradInput to do the actual update
            local module = node.data.module
            gradInput = module:updateGradInput(input,gradOutput)
         end
         -- during backward, hand this node's gradient to each child's data.gradOutput
         -- propagate the output to children
         for i,child in ipairs(node.children) do
            child.data.gradOutput = child.data.gradOutput or {}
            local mapindex = node.data.mapindex[child.data]
            local gi
            if #node.children == 1 then
               gi = gradInput
            else
               gi = gradInput[mapindex]
            end
            table.insert(child.data.gradOutput,gi)
         end
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end
   local outnode = self.outnode
   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
   end
   for _,node in ipairs(self.backwardnodes) do
      local gradOutput = node.data.gradOutput
      while gradOutput and #gradOutput >0 do
         table.remove(gradOutput)
      end
   end
   -- Set the starting gradOutput.
   outnode.data.gradOutput = outnode.data.gradOutput or {}
   outnode.data.gradOutput[1] = gradOutput
   for i,node in ipairs(self.backwardnodes) do
      neteval(node)
   end
   assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
   self.gradInput = self.innode.data.gradOutput[1]
   return self.gradInput
end

As you can see, the backward pass first sums up the gradients arriving at each node, pushes that sum through the node's module, and then hands the resulting gradient to each of the node's children. Doing this for every node in order, the final gradient of the whole network ends up in self.innode.data.gradOutput[1].