[torch]maskzero-brnn

Source: Internet | Editor: 程序博客网 | Time: 2024/05/20 11:34

https://www.bountysource.com/teams/element-research/issues?tracker_ids=13818311

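```lua
require 'nn'
require 'rnn'

local inputdim = 5
local outputdim = 3
local seq = 4
local batch = 2

-- input layout: seqlen x batchsize x inputsize
local model1 = nn.SeqBRNN(inputdim, outputdim)
-- enable zero-masking on both directions of the bidirectional RNN
model1.forwardModule:maskZero()
model1.backwardModule:maskZero()
--model2 = nn.MaskZero(nn.SeqBRNN(inputdim, outputdim),1)

local input = torch.rand(seq, batch, inputdim)
-- zero out step 1 of sequence 1 and the last step of sequence 2
input[1][1]:fill(0)
input[seq][2]:fill(0)
print(input)
print(model1:forward(input))
```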

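Result (the first print is the input, the second is the model output). The two positions whose input vectors were filled with zeros, step 1 of sequence 1 and step 4 of sequence 2, produce all-zero outputs, while every other position is non-zero: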

```
(1,.,.) =
  0.0000  0.0000  0.0000  0.0000  0.0000
  0.2203  0.3390  0.6926  0.8441  0.6361

(2,.,.) =
  0.6061  0.7323  0.3474  0.5499  0.6236
  0.9515  0.2123  0.3901  0.2063  0.1385

(3,.,.) =
  0.6000  0.7315  0.5589  0.9840  0.1418
  0.1604  0.7989  0.6146  0.6593  0.0978

(4,.,.) =
  0.9620  0.0543  0.5116  0.7920  0.9306
  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.DoubleTensor of size 4x2x5]

(1,.,.) =
  0.0000  0.0000  0.0000
  0.4691  0.0980 -0.4581

(2,.,.) =
  0.6086  0.0915 -0.5804
  0.6128 -0.0166 -0.3646

(3,.,.) =
  0.5398 -0.0439 -0.5339
  0.4330 -0.1285 -0.1740

(4,.,.) =
  0.5330 -0.1697 -0.4250
  0.0000  0.0000  0.0000
[torch.DoubleTensor of size 4x2x3]
```
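To make the masking rule in the result visible without Torch, here is a minimal plain-Lua sketch of the effect on the outputs: any (time step, batch) position whose input vector is entirely zero ends up with an all-zero output vector. `isZeroVector` and `applyZeroMask` are hypothetical helpers written only for illustration; they are not part of the rnn package, and the sketch only post-processes outputs. It does not model how SeqBRNN/SeqLSTM treat the recurrent state at masked steps.

```lua
-- Hypothetical helper (not from the rnn package):
-- true when every feature of the vector v is exactly zero.
local function isZeroVector(v)
  for _, x in ipairs(v) do
    if x ~= 0 then return false end
  end
  return true
end

-- Hypothetical helper: for nested tables laid out as seq x batch x dim,
-- zero output[t][b] wherever input[t][b] is an all-zero vector.
-- Modifies output in place and also returns it.
local function applyZeroMask(input, output)
  for t = 1, #input do
    for b = 1, #input[t] do
      if isZeroVector(input[t][b]) then
        for k = 1, #output[t][b] do
          output[t][b][k] = 0
        end
      end
    end
  end
  return output
end

-- Usage on a tiny 2 x 2 x 2 example (seq x batch x dim):
-- step 1 of sequence 1 and step 2 of sequence 2 are all-zero inputs.
local input  = { { {0, 0}, {1, 2} }, { {3, 4}, {0, 0} } }
local output = { { {5, 6}, {7, 8} }, { {9, 10}, {11, 12} } }
applyZeroMask(input, output)
print(output[1][1][1], output[1][2][1], output[2][1][1], output[2][2][1])  -- expected: 0, 7, 9, 0
```

The masked positions mirror the Torch run above: the zeroed input rows map to zeroed output rows, and the remaining positions are left untouched.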