[DL]基于PyTorch的seq2seq模型

来源:互联网 发布:水平角观测记录表数据 编辑:程序博客网 时间:2024/06/06 11:04
import torch
import torch.nn as nn


class RNNEncode(nn.Module):
    """Minimal RNN encoder that folds a sequence into one output and hidden state.

    Each step concatenates the current element with the running hidden state
    and feeds the result through two linear layers: ``i2o`` (step output) and
    ``i2h`` (next hidden state). No activation is applied between steps.

    Args:
        input_size: Number of values in each sequence element.
        hidden_size: Size of the hidden state.
        out_size: Size of the per-step output.
    """

    def __init__(self, input_size=1000, hidden_size=10, out_size=10):
        super(RNNEncode, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.i2o = nn.Linear(self.input_size + self.hidden_size, self.out_size)
        self.i2h = nn.Linear(self.input_size + self.hidden_size, self.hidden_size)
        # Random initial hidden state, drawn once at construction. Registered
        # as a non-persistent buffer so ``.to(device)`` moves it with the
        # module while ``state_dict()`` keeps the same keys as before.
        self.register_buffer(
            "h0", torch.randn(1, self.hidden_size), persistent=False
        )

    def forward(self, words):
        """Encode ``words`` and return ``(last_output, last_hidden)``.

        Args:
            words: A non-empty sequence indexable along its first axis (e.g. a
                2-D tensor); each element must hold ``input_size`` values.

        Returns:
            Tuple ``(out, hidden)`` from the final step, with shapes
            ``(1, out_size)`` and ``(1, hidden_size)``.

        Raises:
            ValueError: If ``words`` is empty.
        """
        if len(words) == 0:
            raise ValueError("RNNEncode.forward() needs at least one element")
        # Every call starts from the stored initial state; the running state
        # stays local so calls neither leak into each other nor keep the
        # previous call's autograd graph alive.
        hidden = self.h0
        out = None
        for word in words:
            combined = torch.cat((word.view(1, -1), hidden.view(1, -1)), dim=1)
            out = self.i2o(combined)
            # NOTE(review): purely linear recurrence (no tanh/ReLU), as in the
            # original; consider adding a nonlinearity.
            hidden = self.i2h(combined)
        return out, hidden


class RNNDecode(nn.Module):
    """Minimal RNN decoder that greedily emits ``max_len`` class indices.

    At every step the same input ``o`` is concatenated with the running hidden
    state; ``i2o`` produces logits and ``i2h`` the next hidden state.

    Args:
        input_size: Number of values in ``o``.
        hidden_size: Size of the hidden state.
        out_size: Number of output classes per step.
        max_len: Number of decoding steps (indices returned).
    """

    def __init__(self, input_size=10, hidden_size=10, out_size=10, max_len=5):
        super(RNNDecode, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.max_len = max_len
        self.i2o = nn.Linear(self.input_size + self.hidden_size, self.out_size)
        self.i2h = nn.Linear(self.input_size + self.hidden_size, self.hidden_size)
        # Explicit dim: the implicit-dim form is deprecated and warns; for the
        # 2-D (1, out_size) logits used here it resolved to dim=1 anyway.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, o, hidden):
        """Decode ``max_len`` steps from input ``o`` and initial ``hidden``.

        Args:
            o: Tensor holding ``input_size`` values (e.g. the encoder output).
            hidden: Tensor holding ``hidden_size`` values.

        Returns:
            List of ``max_len`` tensors of shape ``(1,)``, each the argmax class
            index for that step.
        """
        seq = []
        for _ in range(self.max_len):
            # NOTE(review): ``o`` is re-fed unchanged every step rather than the
            # previous prediction; kept as in the original design.
            combined = torch.cat((o.view(1, -1), hidden.view(1, -1)), dim=1)
            out = self.i2o(combined)
            hidden = self.i2h(combined)
            # Greedy choice: index of the highest probability.
            seq.append(torch.max(self.softmax(out), 1)[1])
        return seq


if __name__ == "__main__":
    model = RNNEncode()
    a = torch.randn(5, 1000)
    o, c = model(a)
    print(o.size())
    model1 = RNNDecode()
    seq = model1(o, c)
    print(seq)
    print('done')
阅读全文
0 0
原创粉丝点击