机器学习笔记8:基于TensorFlow的数据预测
来源:互联网 发布:cf游戏数据异常怎么办 编辑:程序博客网 时间:2024/05/19 20:22
机器学习笔记8:基于TensorFlow的数据预测
本文是在一篇博客预测天朝铁路的客运量一文中学习,代码部分引用该文,对其进行部分修改。
时间序列数据是指在不同时间点上收集到的数据,这类数据反映了某一事物、现象等随时间的变化状态或程度。
铁路客运量历史数据
铁路客运量.csv(2005-2016月度数据)
使用matplotlib画出数据走势
import matplotlib.pyplot as plt import pandas as pd import requests import io import numpy as np url = 'http://blog.topspeedsnail.com/wp-content/uploads/2016/12/铁路客运量.csv' ass_data = requests.get(url).content df = pd.read_csv(io.StringIO(ass_data.decode('utf-8'))) # python2使用StringIO.StringIO data = np.array(df['铁路客运量_当期值(万人)']) # normalize normalized_data = (data - np.mean(data)) / np.std(data) plt.figure() plt.plot(data) plt.show()
利用TensorFlow进行预测
# coding=utf-8'''Author:Chen haoDescription: counterDate: August 22 , 2017'''import numpy as npimport tensorflow as tfimport matplotlib.pyplot as pltimport pandas as pdimport requestsimport io# 加载数据url = 'http://blog.topspeedsnail.com/wp-content/uploads/2016/12/铁路客运量.csv'ass_data = requests.get(url).contentdf = pd.read_csv(io.StringIO(ass_data.decode('utf-8'))) # python2使用StringIO.StringIOdata = np.array(df['铁路客运量_当期值(万人)'])# normalizenormalized_data = (data - np.mean(data)) / np.std(data)seq_size = 3train_x, train_y = [], []for i in range(len(normalized_data) - seq_size - 1): train_x.append(np.expand_dims(normalized_data[i: i + seq_size], axis=1).tolist()) train_y.append(normalized_data[i + 1: i + seq_size + 1].tolist())input_dim = 1X = tf.placeholder(tf.float32, [None, seq_size, input_dim])Y = tf.placeholder(tf.float32, [None, seq_size])# regressiondef ass_rnn(hidden_layer_size=6): W = tf.Variable(tf.random_normal([hidden_layer_size, 1]), name='W') b = tf.Variable(tf.random_normal([1]), name='b') cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_layer_size) outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32) W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1]) out = tf.matmul(outputs, W_repeated) + b out = tf.squeeze(out) return outdef train_rnn(): out = ass_rnn() loss = tf.reduce_mean(tf.square(out - Y)) train_op = tf.train.AdamOptimizer(learning_rate=0.003).minimize(loss) saver = tf.train.Saver(tf.global_variables()) with tf.Session() as sess: # tf.get_variable_scope().reuse_variables() sess.run(tf.global_variables_initializer()) for step in range(9000): _, loss_ = sess.run([train_op, loss], feed_dict={X: train_x, Y: train_y}) if step % 10 == 0: # 用测试数据评估loss print(step, loss_) print("保存模型: ", saver.save(sess, 'ass.model'))def prediction(): out = ass_rnn() saver = tf.train.Saver(tf.global_variables()) with tf.Session() as sess: # tf.get_variable_scope().reuse_variables() saver.restore(sess, './ass.model') prev_seq = train_x[-1] predict = [] for i in range(12): next_seq = sess.run(out, feed_dict={X: [prev_seq]}) predict.append(next_seq[-1]) prev_seq = np.vstack((prev_seq[1:], next_seq[-1])) plt.figure() plt.plot(list(range(len(normalized_data))), normalized_data, color='b') plt.plot(list(range(len(normalized_data), len(normalized_data) + len(predict))), predict, color='r') plt.show()#train_rnn()prediction()
需要注意的是,先要运行train_rnn()函数,然后将训练产生的模型保存到当前的文件夹(屏蔽prediction函数),然后在运行prediction()函数(屏蔽train_rnn()函数)即可获得预测的结果。
运行的结果如下图所示:
运行构建的模型在很多情况下并不理想。
阅读全文
0 0
- 机器学习笔记8:基于TensorFlow的数据预测
- 机器学习笔记1:基于Logistic回归进行数据预测
- 机器学习第一个练手程序 基于决策树的iris数据预测
- 基于Tensorflow的MNIST机器学习
- 机器学习实战笔记-预测数值型数据:回归
- 【机器学习sklearn】基于sklearn的股票预测
- 机器学习实践1:基于logistic regression的性别预测
- 机器学习初学者的TensorFlow笔记
- 关于机器学习在线预测的任务学习笔记
- 【机器学习】Tensorflow学习笔记
- 机器学习实战-8预测数值型数据-回归
- 基于TensorFlow的机器学习(1) -- 基础介绍
- 基于TensorFlow的机器学习(2) -- 回归模型
- 基于Tensorflow的机器学习(4) -- 随机森林
- 基于Tensorflow的机器学习(5) -- 全连接神经网络
- 基于Tensorflow的机器学习(6) -- 卷积神经网络
- 用TensorFlow的Linear/DNNRegrressor预测数据
- 基于大数据学习算法的优惠券预测模型
- Java——代码块
- Linux字符设备驱动
- DEDECMS点击主栏目默认显示第一个子栏目列表的方法
- 利用maven-shade-plugin打包包含所有依赖jar包
- poj-3982(矩阵快速幂+大数模板)
- 机器学习笔记8:基于TensorFlow的数据预测
- 校园招聘-2017滴滴研发工程师内推笔试编程题
- CURL:Protocol http not supported or disabled in libcurl
- JS正则表达式
- php7中使用mongoDB的聚合操作对数据进行分组求和统计操作
- RDD、DataFrame、Dataset介绍
- 【大二最后两题】Hrbust 2064 萌萌哒十五酱的宠物~【思维+树链剖分 / 树上倍增LCA】
- python学习必知---python2.x与python3.x选择
- (六)初始化并设置event