[DL]基于Pytorch的seq2seq模型
来源:互联网 发布:水平角观测记录表数据 编辑:程序博客网 时间:2024/06/06 11:04
import torchimport torch.nn as nnclass RNNEncode(nn.Module): def __init__(self): super(RNNEncode,self).__init__() self.input_size=1000 self.hidden_size=10 self.out_size=10 self.i2o=nn.Linear(self.input_size+self.hidden_size,self.out_size) self.i2h=nn.Linear(self.input_size+self.hidden_size,self.hidden_size) self.h0=torch.autograd.Variable(torch.randn(1,self.hidden_size)) def forward(self,words): global out for i in range(len(words)): input=torch.cat((words[i].view(1,-1),self.h0.view(1,-1)),dim=1) out=self.i2o(input) self.h0=self.i2h(input) return out,self.h0class RNNDecode(nn.Module): def __init__(self): super(RNNDecode,self).__init__() self.input_size=10 self.hidden_size=10 self.out_size=10 self.i2o=nn.Linear(self.input_size+self.hidden_size,self.out_size) self.i2h=nn.Linear(self.input_size+self.hidden_size,self.hidden_size) self.softmax=nn.Softmax() def forward(self,o,hidden): seq=[] for i in range(5): input=torch.cat((o.view(1,-1),hidden.view(1,-1)),dim=1) out=self.i2o(input.view(1,-1)) hidden=self.i2h(input) seq.append(torch.max(self.softmax(out),1)[1]) return seq model=RNNEncode()a=torch.autograd.Variable(torch.randn(5,1000))# print(a)o,c=model(a)print(o.size())model1=RNNDecode()seq=model1(o,c)print(seq)print('done')
阅读全文
0 0
- [DL]基于Pytorch的seq2seq模型
- [DL]基于pytorch的Elman RNN语言模型
- pytorch入门(3)pytorch-seq2seq模型
- [DL]基于Pytorch的Linear classified model
- [DL]基于Pytorch的N-gram Language Model
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型
- 深度学习的seq2seq模型
- 深度学习的seq2seq模型
- 深度学习的seq2seq模型
- 深度学习的seq2seq模型
- <模型汇总-7>基于CNN的Seq2Seq模型-Convolutional Sequence to Sequence Learning
- 模型汇总17 基于Depthwise Separable Convolutions的Seq2Seq模型_SliceNet原理解析
- pytorch 模型的加载
- seq2seq模型
- seq2seq模型
- seq2seq模型
- 搭建SpringMVC项目——02 配置文件pom.xml
- 未来是属于 ARM 为代表的精简指令集还是 x86 为代表的复杂指令集?
- OMCS ——卓尔不群的网络语音视频聊天框架(跨平台)
- 难以相信!比尔盖茨当选中国外籍院士,核能外交让盖茨实至名归!
- 《阿里巴巴Java开发规约》插件使用详细指南
- [DL]基于Pytorch的seq2seq模型
- Android进程间通信的几种方式
- 【CSS.DIV】HTML<li>标签
- Kaldi中nnet3进行语音识别过程中用到的部分工具集锦!!!
- iOS字体大小适配的几种方法
- css实现六边形图片(最简单易懂方法实现高逼格图片展示)
- Kotlin 包和 import 语句使用
- JAVA基础篇
- linux防火墙的配置iptables