Pytorch LSTM 时间序列预测
来源:互联网 发布:2015年水产品出口数据 编辑:程序博客网 时间:2024/05/17 07:20
import torchimport torch.nn as nnfrom torch.autograd import *import torch.optim as optimimport torch.nn.functional as Fimport matplotlib.pyplot as pltimport numpy as npdef SeriesGen(N): x = torch.arange(1,N,0.01) return torch.sin(x)def trainDataGen(seq,k): dat = list() L = len(seq) for i in range(L-k-1): indat = seq[i:i+k] outdat = seq[i+1:i+k+1] dat.append((indat,outdat)) return datdef ToVariable(x): tmp = torch.FloatTensor(x) return Variable(tmp)y = SeriesGen(10)dat = trainDataGen(y.numpy(),10)class LSTMpred(nn.Module): def __init__(self,input_size,hidden_dim): super(LSTMpred,self).__init__() self.input_dim = input_size self.hidden_dim = hidden_dim self.lstm = nn.LSTM(input_size,hidden_dim) self.hidden2out = nn.Linear(hidden_dim,1) self.hidden = self.init_hidden() def init_hidden(self): return (Variable(torch.zeros(1, 1, self.hidden_dim)), Variable(torch.zeros(1, 1, self.hidden_dim))) def forward(self,seq): lstm_out, self.hidden = self.lstm( seq.view(len(seq), 1, -1), self.hidden) outdat = self.hidden2out(lstm_out.view(len(seq),-1)) return outdatmodel = LSTMpred(1,6)loss_function = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=0.01)for epoch in range(10): print(epoch) for seq, outs in dat[:700]: seq = ToVariable(seq) outs = ToVariable(outs) #outs = torch.from_numpy(np.array([outs])) optimizer.zero_grad() model.hidden = model.init_hidden() modout = model(seq) loss = loss_function(modout, outs) loss.backward() optimizer.step()predDat = []for seq, trueVal in dat[700:]: seq = ToVariable(seq) trueVal = ToVariable(trueVal) predDat.append(model(seq)[-1].data.numpy()[0])fig = plt.figure()plt.plot(y.numpy())plt.plot(range(700,890),predDat)plt.show()
阅读全文
0 0
- Pytorch LSTM 时间序列预测
- LSTM预测时间序列
- LSTM预测时间序列
- LSTM 时间序列预测 matlab
- Tensorflow LSTM时间序列预测的尝试
- Tensorflow LSTM时间序列预测的尝试
- 推荐lstm时间序列预测论文和资料
- python利用LSTM进行时间序列分析预测
- python利用LSTM进行时间序列分析预测
- Python中利用LSTM模型进行时间序列预测分析
- 用 LSTM 做时间序列预测的一个小例子
- Python中用Keras构建LSTM模型进行时间序列预测
- Python中使用LSTM网络进行时间序列预测
- 基于Keras的LSTM多变量时间序列预测
- 用 LSTM 做时间序列预测的一个小例子
- 使用tensorflow的lstm网络进行时间序列预测
- Python时间序列LSTM预测系列教程(10)-多步预测
- Python时间序列LSTM预测系列教程(11)-多步预测
- 菜鸟上路!
- Spring Data JPA注解@DynamicInsert和@DynamicUpdate
- Unity 移动端 The file 'none' is corrupted. 报错解决
- 对People类分析
- Python CGI编程
- Pytorch LSTM 时间序列预测
- 孤儿进程和僵尸进程
- Placing probes using scripting
- 文章标题
- 直方图均衡化
- 查看浏览器内核版本测试网站
- JZOJsenior3479.【NOIP2013模拟联考9】工作安排(work)
- 前端页面中的iframe框架的实践
- 使用JQuery定时关闭模态框