[torch] Updating multiple SeqLSTMs at the same time


nn.SeqLSTM needs some intermediate state during backward(). That state is produced by each seqLSTM:forward(input) call, and every new forward resets it, so a backward() only yields correct gradients when it immediately follows the matching forward.
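To make that concrete, here is a minimal sketch (not from the original post; the sizes match the scripts below) showing that a backward() issued after an unrelated forward() no longer agrees with a matched forward/backward pair:

require 'nn'
require 'rnn'

local lstm = nn.SeqLSTM(6, 4)               -- feat_dim = 6, hidden_size = 4
local a = torch.randn(10, 5, 6)             -- seq_len x batch_size x feat_dim
local b = torch.randn(10, 5, 6)
local gradOut = torch.randn(10, 5, 4)

-- backward right after the matching forward: gradients use a's saved state
lstm:forward(a)
local gradA1 = lstm:backward(a, gradOut):clone()

-- forward b in between: b's forward overwrites the state saved for a,
-- so the second backward for a is computed from the wrong buffers
lstm:forget(); lstm:zeroGradParameters()
lstm:forward(a)
lstm:forward(b)
local gradA2 = lstm:backward(a, gradOut):clone()

print((gradA1 - gradA2):abs():max())        -- non-zero: the two disagree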

maptable
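The first attempt uses nn.MapTable, which applies one shared module to every entry of an input table. The script below starts four models from identical saved weights and compares: the MapTable update (model), a single model that forwards only the last input before backwarding through every input (model1), and two reference loops that forward/backward/update one input at a time, in forward and reverse order respectively (model2, model3).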

require 'nn'
require 'rnn'
require 'os'

local batch_size  = 5
local feat_dim    = 6
local hidden_size = 4
local seq_len     = 10
local num         = 10
local lr          = 0.01

-------- initialize model (run once to create the snapshot on disk)
--[[
local model = nn.SeqLSTM(feat_dim, hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim, hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]

-- four copies of the same initial weights, one per update strategy
local model  = torch.load("model_init_seqLSTM.t7")
local model1 = torch.load("model_init_seqLSTM.t7")
local model2 = torch.load("model_init_seqLSTM.t7")
local model3 = torch.load("model_init_seqLSTM.t7")
--local params, gradparams = model:getParameters()
--local criterion = nn.SequencerCriterion(nn.MSECriterion())

------------- input, label
local input = {}
local gradOut = {}
for i = 1, num do
    local x = torch.randn(seq_len, batch_size, feat_dim)
    local y = torch.randn(seq_len, batch_size, hidden_size)
    table.insert(input, x)
    table.insert(gradOut, y)
end

------------- map: one shared module applied to every table entry
local map = nn.MapTable():add(model)
local out = map:forward(input)
map:backward(input, gradOut)
map:updateParameters(lr)
map:zeroGradParameters()

---------- model single: forward only the last input, then backward through all of them
-- (every backward below reuses the intermediate state left by forward(input[num]))
loss = 0
--[[
out = {}
for i = 1, num do
    out[i] = model1:forward(input[i])
end
--]]
out = model1:forward(input[num])
for i = num, 1, -1 do
    gradInputs = model1:backward(input[i], gradOut[i])
    model1:updateParameters(lr)
end
model1:forget()
model1:zeroGradParameters()

---------- model main (true value): forward/backward/update one input at a time, in order
loss = 0
out = {}
for i = 1, num do
    out = model2:forward(input[i])
    gradInputs = model2:backward(input[i], gradOut[i])
    model2:updateParameters(lr)
    model2:forget()
    model2:zeroGradParameters()
end

---------- model main2 (true value): same, but iterating in reverse order
loss = 0
out = {}
for i = num, 1, -1 do
    out = model3:forward(input[i])
    gradInputs = model3:backward(input[i], gradOut[i])
    model3:updateParameters(lr)
    model3:forget()
    model3:zeroGradParameters()
end

---------- forward again and compare the updated models
out = map:forward(input)
out_single, loss = {}, 0
for i, k in pairs(out) do
    out1 = model1:forward(input[i])
    out2 = model2:forward(input[i])
    out3 = model3:forward(input[i])
    print(i)
    --print(out2) --true value
    --print(out3) --true value2
    --print(k)    --maptable
    --print(out1) --model single
    print(out2 + out3)
    print(k * 2)      --maptable
    --print(out1 * 2) --model single (this one is quite different from the two methods above)
end
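In the comparison at the end, the MapTable output k (printed as k*2) lines up with out2+out3 from the two sequential runs, while model1's output is quite different: all of its backward() calls reused the stale intermediate state left behind by the single forward(input[num]), so its updates are wrong.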

new way
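A shared module cannot keep separate intermediate state for several inputs at once, so the second script loads one fresh copy of the model per input instead. Before each copy's forward/backward it pulls in the parameters updated by the previous copy, and after the loop it broadcasts the final parameters (held by models[1]) back to every copy. The result is checked against a single model updated sequentially; the two printouts should match.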

require 'nn'
require 'rnn'
require 'os'

local batch_size  = 5
local feat_dim    = 6
local hidden_size = 4
local seq_len     = 10
local num         = 2
local lr          = 0.01

-------- initialize model (run once to create the snapshot on disk)
--[[
local model = nn.SeqLSTM(feat_dim, hidden_size)
--local model = nn.Sequencer(nn.Linear(feat_dim, hidden_size))
model:clearState()
torch.save("model_init_seqLSTM.t7", model)
os.exit()
--]]

local model1 = torch.load("model_init_seqLSTM.t7")
--local params, gradparams = model:getParameters()
--local criterion = nn.SequencerCriterion(nn.MSECriterion())

------------- input, label
local input = {}
local gradOut = {}
for i = 1, num do
    local x = torch.randn(seq_len, batch_size, feat_dim)
    local y = torch.randn(seq_len, batch_size, hidden_size)
    table.insert(input, x)
    table.insert(gradOut, y)
end

------------------- new way: one freshly loaded copy per input
local models, out, gradInputs = {}, {}, {}
for i = num, 1, -1 do
    local model = torch.load("model_init_seqLSTM.t7")
    models[i] = model
    if i < num then
        -- pull the parameters updated by the previous copy into this one
        params_cur, gradParams_cur = models[i]:getParameters()
        params_updated, gradParams_updated = models[i+1]:getParameters()
        for j = 1, (#params_cur)[1] do
            params_cur[j] = params_updated[j]
            gradParams_cur[j] = gradParams_updated[j]
        end
    end
    -- each copy keeps its own intermediate state, so its backward stays valid
    out[i] = models[i]:forward(input[i])
    gradInputs[i] = models[i]:backward(input[i], gradOut[i])
    models[i]:updateParameters(lr)
    models[i]:forget()
    models[i]:zeroGradParameters()
end

-- broadcast the final parameters (held by models[1]) back to every copy
params_updated, gradParams_updated = models[1]:getParameters()
for i = 2, num do
    params_cur, gradParams_cur = models[i]:getParameters()
    for j = 1, (#params_cur)[1] do
        params_cur[j] = params_updated[j]
        gradParams_cur[j] = gradParams_updated[j]
    end
end

---------------------- true one: a single model updated sequentially
loss = 0
out = {}
for i = num, 1, -1 do
    out = model1:forward(input[i])
    gradInputs = model1:backward(input[i], gradOut[i])
    model1:updateParameters(lr)
    model1:forget()
    model1:zeroGradParameters()
end

---------- check results: the two outputs should match
for i = 1, num do
    out0 = models[i]:forward(input[i])
    out1 = model1:forward(input[i])
    print(i, out0, out1)
end
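As a side note, the element-by-element copy loops above work because getParameters() returns the parameters and gradients as flat 1-D tensors, so the same sync can be done with one vectorized Tensor:copy() per tensor. A drop-in sketch for the inner j-loop, assuming the same models table and loop index i as in the script above:

-- vectorized equivalent of the element-wise j-loop
params_cur, gradParams_cur = models[i]:getParameters()
params_updated, gradParams_updated = models[i+1]:getParameters()
params_cur:copy(params_updated)           -- copy all parameters at once
gradParams_cur:copy(gradParams_updated)   -- copy all accumulated gradients at once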