Python时间序列LSTM预测系列教程(9)-多变量
来源:互联网 发布:染发剂 知乎 编辑:程序博客网 时间:2024/05/21 17:48
多变量LSTM预测模型(3)
教程原文链接
前置教程:
Python时间序列LSTM预测系列教程(7)-多变量
Python时间序列LSTM预测系列教程(8)-多变量
定义&训练模型
1、数据划分成训练和测试数据
本教程用第一年数据做训练,剩余4年数据做评估
2、输入=1时间步长,8个feature
3、第一层隐藏层节点=50,输出节点=1
4、用平均绝对误差MAE做损失函数、Adam的随机梯度下降做优化
5、epoch=50, batch_size=72
模型评估
1、预测后需要做逆缩放
2、用RMSE做评估
代码解析
# coding=utf-8 from math import sqrtfrom numpy import concatenatefrom matplotlib import pyplotfrom pandas import read_csvfrom pandas import DataFramefrom pandas import concatfrom sklearn.preprocessing import MinMaxScalerfrom sklearn.preprocessing import LabelEncoderfrom sklearn.metrics import mean_squared_errorfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.layers import LSTM #转成有监督数据def series_to_supervised(data, n_in=1, n_out=1, dropnan=True): n_vars = 1 if type(data) is list else data.shape[1] df = DataFrame(data) cols, names = list(), list() #数据序列(也将就是input) for i in range(n_in, 0, -1): cols.append(df.shift(i)) names+=[('var%d(t-%d)'%(j+1, i)) for j in range(n_vars)] #预测数据(input对应的输出值) for i in range(0, n_out, 1): cols.append(df.shift(-i)) if i==0: names+=[('var%d(t)'%(j+1)) for j in range(n_vars)] else: names+=[('var%d(t+%d))'%(j+1, i)) for j in range(n_vars)] #拼接 agg = concat(cols, axis=1) if dropnan: agg.dropna(inplace=True) return agg #数据预处理#--------------------------dataset = read_csv('data_set/air_pollution_new.csv', header=0, index_col=0)values = dataset.values #标签编码encoder = LabelEncoder()values[:,4] = encoder.fit_transform(values[:,4])#保证为floatvalues = values.astype('float32')#归一化 scaler = MinMaxScaler(feature_range=(0,1))scaled = scaler.fit_transform(values)#转成有监督数据reframed = series_to_supervised(scaled, 1, 1)#删除不预测的列reframed.drop(reframed.columns[9:16], axis=1, inplace=True)print reframed.head() #数据准备#--------------------------values = reframed.valuesn_train_hours = 365*24 #拿一年的时间长度训练#划分训练数据和测试数据train = values[:n_train_hours, :]test = values[n_train_hours:, :]#拆分输入输出train_x, train_y = train[:, :-1], train[:, -1]test_x, test_y = test[:, :-1], test[:, -1]#reshape输入为LSTM的输入格式train_x = train_x.reshape((train_x.shape[0], 1, train_x.shape[1]))test_x = test_x.reshape((test_x.shape[0], 1, test_x.shape[1]))print 'train_x.shape, train_y.shape, test_x.shape, test_y.shape'print train_x.shape, train_y.shape, test_x.shape, test_y.shape #模型定义#-------------------------model = Sequential()model.add(LSTM(50, input_shape=(train_x.shape[1], train_x.shape[2])))model.add(Dense(1))model.compile(loss='mae', optimizer='adam') #模型训练#------------------------history = model.fit(train_x, train_y, epochs=50, batch_size=72, validation_data=(test_x, test_y), verbose=2, shuffle=False) #输出pyplot.plot(history.history['loss'], label='train')pyplot.plot(history.history['val_loss'], label='test')pyplot.legend()pyplot.show() #预测#------------------------yhat = model.predict(test_x)test_x = test_x.reshape(test_x.shape[0], test_x.shape[2])#预测数据逆缩放inv_yhat = concatenate((yhat, test_x[:, 1:]), axis=1) inv_yhat = scaler.inverse_transform(inv_yhat)inv_yhat = inv_yhat[:, 0]#真实数据逆缩放test_y = test_y.reshape(len(test_y), 1)inv_y = concatenate((test_y, test_x[:, 1:]), axis=1)inv_y = scaler.inverse_transform(inv_y)inv_y = inv_y[:, 0]#计算rmsermse = sqrt(mean_squared_error(inv_y, inv_yhat))print 'Test RMSE:%.3f'%rmse
阅读全文
0 0
- Python时间序列LSTM预测系列教程(9)-多变量
- Python时间序列LSTM预测系列教程(7)-多变量
- Python时间序列LSTM预测系列教程(8)-多变量
- 基于Keras的LSTM多变量时间序列预测
- 教你搭建多变量时间序列预测模型LSTM(附代码、数据集)
- Python时间序列LSTM预测系列教程(10)-多步预测
- Python时间序列LSTM预测系列教程(11)-多步预测
- Python时间序列LSTM预测系列教程(1)-单变量
- Python时间序列LSTM预测系列教程(2)-单变量
- Python时间序列LSTM预测系列教程(3)-单变量
- Python时间序列LSTM预测系列教程(4)-单变量
- Python时间序列LSTM预测系列教程(5)-单变量
- Python时间序列LSTM预测系列教程(6)-单变量
- 代码干货 | 基于Keras的LSTM多变量时间序列预测
- LSTM预测时间序列
- LSTM预测时间序列
- LSTM 时间序列预测 matlab
- Pytorch LSTM 时间序列预测
- mysql 截取字符串: left() ,right() 和IFNULL()用法
- 新人报道!
- UGUI
- 百度console招聘信息
- activiti学习--12 个人任务及三种分配方式:直接设置代理人+流程变量设置代理人+实现类的方式设置代理人+将任务代理人设置为别人
- Python时间序列LSTM预测系列教程(9)-多变量
- java---implements
- 微信小程序----组件之input
- 初探IntelliJ IDEA
- Linux学习-shell(五)
- redis实战:redis自动备份与备份管理
- 模型评估-交叉验证与自助法
- RecyclerView复选框/各种展示/分割线
- Linux环境编程