[torch]同时更新多个seqlstm
来源:互联网 发布:json 数字 不带双引号 编辑:程序博客网 时间:2024/06/06 08:52
nn.SeqLSTM 在 backward 的时候需要一些中间参数, 这些中间参数是由 seqLSTM:forward(input) 时生成的, 并且, 每 forward 一次, 这些中间参数就会被重置. 因此如果在 backward 之前多次 forward, backward 用到的中间状态就不再对应正确的那次 forward.
方法 1: 使用 nn.MapTable 共享同一个 SeqLSTM 模块
-- Experiment: updating one SeqLSTM against multiple sequences.
-- Compares three strategies and prints their outputs side by side:
--   * nn.MapTable sharing a single SeqLSTM (forwards all inputs, then backwards)
--   * a single SeqLSTM forwarded ONCE and backward-ed for every input
--     (deliberately wrong: SeqLSTM's cached intermediates from forward()
--     are reused/stale, so its gradients are incorrect — this is the bug
--     being demonstrated)
--   * a single SeqLSTM with a full forward/backward/update cycle per input,
--     in ascending and in descending input order (the two reference results)
require 'nn'
require 'rnn'
require 'os'

local batch_size = 5
local feat_dim = 6
local hidden_size = 4
local seq_len = 10
local num = 10
local lr = 0.01

-------- initialize model
-- Run once to create the shared initial checkpoint, then keep commented out
-- so that every strategy below starts from identical weights.
--[[
local model = nn.SeqLSTM(feat_dim,hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim,hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]
-- Four independent copies of the same initial weights, one per strategy.
local model = torch.load("model_init_seqLSTM.t7")
local model1 = torch.load("model_init_seqLSTM.t7")
local model2 = torch.load("model_init_seqLSTM.t7")
local model3 = torch.load("model_init_seqLSTM.t7")
--local params, gradparams = model:getParameters()
--local criterion = nn.SequencerCriterion(nn.MSECriterion())

------------- input, label
-- num random (seq_len x batch x feat) inputs and matching upstream gradients.
local input = {}
local gradOut = {}
for i = 1, num do
  local x = torch.randn(seq_len, batch_size, feat_dim)
  local y = torch.randn(seq_len, batch_size, hidden_size)
  table.insert(input, x)
  table.insert(gradOut, y)
end

------------- map: MapTable clones the shared module once per table entry
local map = nn.MapTable():add(model)
local out = map:forward(input)
map:backward(input, gradOut)
map:updateParameters(lr)
map:zeroGradParameters()

---------- model single
-- Intentionally incorrect baseline: forward only the LAST input, then call
-- backward for every input. SeqLSTM's internal buffers still hold the state
-- of that single forward pass, so all but one backward use stale caches.
--[[
out = {}
for i = 1, num do
  out[i] = model1:forward(input[i])
end
--]]
out = model1:forward(input[num])
for i = num, 1, -1 do
  local gradInputs = model1:backward(input[i], gradOut[i])
  model1:updateParameters(lr)
end
model1:forget()
model1:zeroGradParameters()

---------- model main (true value): full cycle per input, ascending order
for i = 1, num do
  out = model2:forward(input[i])
  local gradInputs = model2:backward(input[i], gradOut[i])
  model2:updateParameters(lr)
  model2:forget()
  model2:zeroGradParameters()
end

---------- model main2 (true value): full cycle per input, descending order
for i = num, 1, -1 do
  out = model3:forward(input[i])
  local gradInputs = model3:backward(input[i], gradOut[i])
  model3:updateParameters(lr)
  model3:forget()
  model3:zeroGradParameters()
end

---------- forward again: compare the updated models on the same inputs
out = map:forward(input)
for i, k in pairs(out) do
  local out1 = model1:forward(input[i])
  local out2 = model2:forward(input[i])
  local out3 = model3:forward(input[i])
  print(i)
  --print(out2) --true value
  --print(out3) --true value2
  --print(k) --maptable
  --print(out1) --model single
  print(out2 + out3)
  print(k * 2) --maptable
  --print(out1*2) --model single (this one is quite different from above two methods)
end
方法 2: 为每个输入单独加载一个模型副本, 在副本之间手工同步参数
-- Experiment: emulate sequential training of one SeqLSTM by loading a fresh
-- model copy per input, synchronising parameters from the previously updated
-- copy before each step. The result is checked against a single model trained
-- with a full forward/backward/update cycle per input (descending order).
require 'nn'
require 'rnn'
require 'os'

local batch_size = 5
local feat_dim = 6
local hidden_size = 4
local seq_len = 10
local num = 2
local lr = 0.01

-------- initialize model
-- Run once to create the shared initial checkpoint, then keep commented out.
--[[
local model = nn.SeqLSTM(feat_dim,hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim,hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]
local model1 = torch.load("model_init_seqLSTM.t7")
--local params, gradparams = model:getParameters()
--local criterion = nn.SequencerCriterion(nn.MSECriterion())

------------- input, label
local input = {}
local gradOut = {}
for i = 1, num do
  local x = torch.randn(seq_len, batch_size, feat_dim)
  local y = torch.randn(seq_len, batch_size, hidden_size)
  table.insert(input, x)
  table.insert(gradOut, y)
end

------------------- new way: one freshly loaded model per input
local models, out, gradInputs = {}, {}, {}
for i = num, 1, -1 do
  models[i] = torch.load("model_init_seqLSTM.t7")
  if i < num then
    -- Carry the already-updated parameters/gradients of the previous step
    -- (models[i+1]) into the fresh copy before training on input[i].
    local params_cur, gradParams_cur = models[i]:getParameters()
    local params_updated, gradParams_updated = models[i + 1]:getParameters()
    params_cur:copy(params_updated)
    gradParams_cur:copy(gradParams_updated)
  end
  out[i] = models[i]:forward(input[i])
  gradInputs[i] = models[i]:backward(input[i], gradOut[i])
  -- was hard-coded 0.01; use the shared learning-rate constant
  models[i]:updateParameters(lr)
  models[i]:forget()
  models[i]:zeroGradParameters()
end

-- models[1] now holds the final parameters; broadcast them to all copies
-- so every models[i] agrees for the comparison below.
local params_final, gradParams_final = models[1]:getParameters()
for i = 2, num do
  local params_cur, gradParams_cur = models[i]:getParameters()
  params_cur:copy(params_final)
  gradParams_cur:copy(gradParams_final)
end

---------------------- true one: single model, full cycle per input
for i = num, 1, -1 do
  out = model1:forward(input[i])
  gradInputs = model1:backward(input[i], gradOut[i])
  model1:updateParameters(lr)
  model1:forget()
  model1:zeroGradParameters()
end

---------- check results: the two strategies should produce matching outputs
for i = 1, num do
  local out0 = models[i]:forward(input[i])
  local out1 = model1:forward(input[i])
  print(i, out0, out1)
end
阅读全文
0 0
- [torch]同时更新多个seqlstm
- 数据库多个session同时更新一行
- hibernate hql 同时更新多个字段
- 数据库同时更新多个字段
- 在同一个数据集中同时更新多表..............
- Coolite Cool Study 2 同时更新多个Tab
- oracle 同时更新(update)多个字段多个值
- oracle 同时更新(update)多个字段多个值
- oracle 同时更新(update)多个字段多个值
- oracle 同时更新(update)多个字段多个值
- oracle 同时更新(update)多个字段多个值
- torch 变量更新
- 同时更新多条语句
- 同时更新多条数据
- 多个DW同时更新,且表中有关系存在,需要注意的击点问题。
- 多个DW同时更新,且表中有关系存在,需要注意的击点问题。
- 让你的系统能够同时访问多个网段的IP [2013-08-07更新]
- 同时删除多个文件
- 编写App测试用例的关注点
- 200行的Node爬虫花了半天的时间把网易云上的30万首歌曲信息都抓取回来
- 织梦网站安全防范操作
- iframe 引用微信公众号文章图片不显示问题
- CURL命令使用
- [torch]同时更新多个seqlstm
- UICollectionViewLayout的简单使用(简单瀑布流)
- ELK-5.4.1和x-pack权限控制 安装指导
- SO_SNDTIMEO和SO_RCVTIMEO
- DB2、Oracle命令行导入/导出数据
- Tomcat源码解析(9)
- 报错com.neenbedankt.android-apt not found如何解决
- 简单自定义选择按钮(switchDemo)
- 整数中1出现的次数