[torch]Save initial state(fastlstm)
来源:互联网 发布:硫磺岛战役知乎 编辑:程序博客网 时间:2024/06/05 19:58
installpath/torch/rnn/FastLSTM.lua
before
local FastLSTM, parent = torch.class("nn.FastLSTM", "nn.LSTM")

-- set this to true to have it use nngraph instead of nn
-- setting this to true can make your next FastLSTM significantly faster
FastLSTM.usenngraph = false
FastLSTM.bn = false

function FastLSTM:__init(inputSize, outputSize, rho, eps, momentum, affine)
   -- initialize batch norm variance with 0.1
   self.eps = eps or 0.1
   self.momentum = momentum or 0.1 --gamma
   self.affine = affine == nil and true or affine
   parent.__init(self, inputSize, outputSize, rho)
end

function FastLSTM:buildModel()
   -- input : {input, prevOutput, prevCell}
   -- output : {output, cell}

   -- Calculate all four gates in one go : input, hidden, forget, output
   self.i2g = nn.Linear(self.inputSize, 4 * self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 4 * self.outputSize)

   if self.usenngraph or self.bn then
      require 'nngraph'
      return self:nngraphModel()
   end

   -- Sum the input->gates and output->gates projections, then split the
   -- result into the four per-gate activations.
   local projections = nn.ParallelTable():add(self.i2g):add(self.o2g)
   local gates = nn.Sequential()
   gates:add(nn.NarrowTable(1, 2))
   gates:add(projections)
   gates:add(nn.CAddTable())
   -- Reshape to (batch_size, n_gates, hid_size)
   -- Then slice the n_gates dimension, i.e dimension 2
   gates:add(nn.Reshape(4, self.outputSize))
   gates:add(nn.SplitTable(1, 2))
   local nonlinearities = nn.ParallelTable()
   nonlinearities:add(nn.Sigmoid()):add(nn.Tanh()):add(nn.Sigmoid()):add(nn.Sigmoid())
   gates:add(nonlinearities)

   -- Keep the previous cell alongside the gate activations.
   local gatesAndCell = nn.ConcatTable()
   gatesAndCell:add(gates):add(nn.SelectTable(3))
   local seq = nn.Sequential()
   seq:add(gatesAndCell)
   seq:add(nn.FlattenTable()) -- input, hidden, forget, output, cell

   -- input gate * hidden state
   local hidden = nn.Sequential()
   hidden:add(nn.NarrowTable(1, 2))
   hidden:add(nn.CMulTable())

   -- forget gate * cell
   local cell = nn.Sequential()
   local forgetAndCell = nn.ConcatTable()
   forgetAndCell:add(nn.SelectTable(3)):add(nn.SelectTable(5))
   cell:add(forgetAndCell)
   cell:add(nn.CMulTable())

   -- nextCell = input*hidden + forget*cell
   local nextCell = nn.Sequential()
   local cellSum = nn.ConcatTable()
   cellSum:add(hidden):add(cell)
   nextCell:add(cellSum)
   nextCell:add(nn.CAddTable())

   local cellAndOutGate = nn.ConcatTable()
   cellAndOutGate:add(nextCell):add(nn.SelectTable(4))
   seq:add(cellAndOutGate)
   seq:add(nn.FlattenTable()) -- nextCell, outputGate

   -- output = outputGate * tanh(nextCell)
   local cellAct = nn.Sequential()
   cellAct:add(nn.SelectTable(1))
   cellAct:add(nn.Tanh())
   local actAndGate = nn.ConcatTable()
   actAndGate:add(cellAct):add(nn.SelectTable(2))
   local output = nn.Sequential()
   output:add(actAndGate)
   output:add(nn.CMulTable())

   -- final module output : {output, nextCell}
   local outputAndCell = nn.ConcatTable()
   outputAndCell:add(output):add(nn.SelectTable(1))
   seq:add(outputAndCell)

   return seq
end
after
require 'hdf5'

local FastLSTM, parent = torch.class("nn.FastLSTM", "nn.LSTM")

-- set this to true to have it use nngraph instead of nn
-- setting this to true can make your next FastLSTM significantly faster
FastLSTM.usenngraph = false
FastLSTM.bn = false

-- initialfile : path to an HDF5 file used for saving/restoring the initial
--               weights (0 / nil disables the feature).
-- ifLoad      : 1 (or "1"/true) means load initialfile to initialize;
--               0 (or "0"/false/nil) means save the fresh weights to initialfile.
function FastLSTM:__init(inputSize, outputSize, rho, eps, momentum, affine, initialfile, ifLoad)
   -- initialize batch norm variance with 0.1
   self.eps = eps or 0.1
   self.momentum = momentum or 0.1 --gamma
   self.affine = affine == nil and true or affine
   self.initialfile = initialfile or 0
   self.ifLoad = ifLoad
   parent.__init(self, inputSize, outputSize, rho)
end

function FastLSTM:buildModel()
   -- input : {input, prevOutput, prevCell}
   -- output : {output, cell}

   -- Calculate all four gates in one go : input, hidden, forget, output
   self.i2g = nn.Linear(self.inputSize, 4 * self.outputSize)
   self.o2g = nn.LinearNoBias(self.outputSize, 4 * self.outputSize)

   if self.initialfile ~= 0 then
      -- BUG FIX: the original tested `if self.ifLoad then`, but in Lua only
      -- nil and false are falsy — the number 0 (the documented "save" value)
      -- is truthy, so the save branch was unreachable. Compare explicitly so
      -- 0, "0", false and nil all select the save branch as documented.
      local wantLoad = self.ifLoad ~= nil and self.ifLoad ~= false
                       and self.ifLoad ~= 0 and self.ifLoad ~= "0"
      if wantLoad then
         -- use hdf5 to initialize: restore weights and gradient buffers
         local myFile = hdf5.open(self.initialfile, 'r')
         self.i2g.weight = myFile:read('i2g_weight'):all()
         self.i2g.bias = myFile:read('i2g_bias'):all()
         self.i2g.gradWeight = myFile:read('i2g_gradWeight'):all()
         self.i2g.gradBias = myFile:read('i2g_gradBias'):all()
         self.o2g.weight = myFile:read('o2g_weight'):all()
         self.o2g.gradWeight = myFile:read('o2g_gradWeight'):all()
         myFile:close()
      else
         -- dump the freshly initialized weights so a later run can reproduce
         -- exactly the same starting state
         local myFile = hdf5.open(self.initialfile, 'w')
         myFile:write('i2g_weight', self.i2g.weight)
         myFile:write('i2g_bias', self.i2g.bias)
         myFile:write('i2g_gradWeight', self.i2g.gradWeight)
         myFile:write('i2g_gradBias', self.i2g.gradBias)
         myFile:write('o2g_weight', self.o2g.weight)
         myFile:write('o2g_gradWeight', self.o2g.gradWeight)
         myFile:close()
      end
   end

   if self.usenngraph or self.bn then
      require 'nngraph'
      return self:nngraphModel()
   end

   -- ... (rest of buildModel unchanged from the stock rnn package)
end
Then rebuild and reinstall the modified rnn package:
cd ~/installpath/torch/rnn
rm -r build/
luarocks make rocks/rnn-scm-1.rockspec
0 0
- [torch]Save initial state(fastlstm)
- activity state save
- Save/Restore Your Activity State
- 0009-APP-Activity-Save-Restore-State
- torch
- Torch
- Torch
- Torch
- Torch
- Torch
- Torch
- Initial
- Initial
- Android ApiDemos示例解析(14):App->Activity->Save & Restore State
- Android API Demos学习(2) - Save & Restore State
- Android ApiDemos示例解析(14):App->Activity->Save & Restore State
- tensorflow学习笔记(三十七):如何自定义LSTM的initial state
- save
- 允许局域网内其他主机访问本地MySql数据库
- Struts2接收请求参数1
- 手机app测试方法(一)基本流程
- 快速排序
- 蓝桥杯 算法提高 9-1九宫格
- [torch]Save initial state(fastlstm)
- 每日一篇算法题——数独
- PHP源码解析笔记2-生命周期和Zend引擎
- Java四种权限修饰符(public, default, protected, private)的用法和对比
- hadoop原生版安装部署---3.hdfs
- HibernateSpatial4.3+postgresql的使用
- 一,Launch启动图片,隐藏上面的状态栏
- centos6.5 Building the main Guest Additions module
- 仿微信输入框静态