7. A Basic TensorFlow RNN/LSTM Example
Source: Internet · Editor: 程序博客网 · Date: 2024/06/05 01:08
Before running, create a `save` folder in the same directory as the code; the trained model checkpoint will be saved there.
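On Linux/macOS the folder can be created from the shell (assuming the current working directory is where the script lives):

```shell
# -p avoids an error if the folder already exists
mkdir -p save
```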
Note that this example targets the old TensorFlow 0.x API: `tf.batch_matmul`, `tf.initialize_all_variables`, and `from tensorflow.python.ops import rnn, rnn_cell` were renamed or removed in TensorFlow 1.0 and later.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell


class SeriesPredictor:
    def __init__(self, input_dim, seq_size, hidden_dim=10):
        # Hyperparameters
        self.input_dim = input_dim
        self.seq_size = seq_size
        self.hidden_dim = hidden_dim

        # Weight variables and input placeholders
        self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]), name='W_out')
        self.b_out = tf.Variable(tf.random_normal([1]), name='b_out')
        self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim])
        self.y = tf.placeholder(tf.float32, [None, seq_size])

        # Cost optimizer
        self.cost = tf.reduce_mean(tf.square(self.model() - self.y))
        self.train_op = tf.train.AdamOptimizer().minimize(self.cost)

        # Auxiliary ops
        self.saver = tf.train.Saver()

    def model(self):
        """
        :param x: inputs of size [T, batch_size, input_size]
        :param W: matrix of fully-connected output layer weights
        :param b: vector of fully-connected output layer biases
        """
        cell = rnn_cell.BasicLSTMCell(self.hidden_dim)
        outputs, states = rnn.dynamic_rnn(cell, self.x, dtype=tf.float32)
        num_examples = tf.shape(self.x)[0]
        # Tile W_out across the batch so each timestep's hidden state
        # can be mapped to a scalar prediction
        W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1])
        out = tf.batch_matmul(outputs, W_repeated) + self.b_out
        out = tf.squeeze(out)
        return out

    def train(self, train_x, train_y):
        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()
            sess.run(tf.initialize_all_variables())
            for i in range(1000):
                _, mse = sess.run([self.train_op, self.cost],
                                  feed_dict={self.x: train_x, self.y: train_y})
                if i % 100 == 0:
                    print(i, mse)
            save_path = self.saver.save(sess, 'save/model.ckpt')
            print('Model saved to {}'.format(save_path))

    def test(self, test_x):
        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()
            self.saver.restore(sess, 'save/model.ckpt')
            output = sess.run(self.model(), feed_dict={self.x: test_x})
            print(output)


if __name__ == '__main__':
    # Input dimension is 1; each prediction uses a sequence of 4 steps
    predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)
    train_x = [[[1], [2], [5], [6]],
               [[5], [7], [7], [8]],
               [[3], [4], [5], [7]]]
    train_y = [[1, 3, 7, 11],
               [5, 12, 14, 15],
               [3, 7, 9, 12]]
    predictor.train(train_x, train_y)
    test_x = [[[1], [2], [3], [4]],  # expected: 1, 3, 5, 7
              [[4], [5], [6], [7]]]  # expected: 4, 9, 11, 13
    predictor.test(test_x)
```
The training log and predictions look like this:
```
0 79.4345
100 49.6515
200 23.4903
300 11.8844
400 6.26166
500 4.21735
600 2.90793
700 2.06972
800 1.50332
900 1.09277
Model saved to save/model.ckpt
[[ 0.68482286  2.59259105  4.72971964  6.3778677 ]
 [ 4.54114914  9.24342632 11.5808382  12.61558914]]
```
The target the network learns is the sum of the current and previous input values, so the correct outputs for the test sequences are 1, 3, 5, 7 and 4, 9, 11, 13. The predictions are reasonably close.
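In other words, the target sequence is the running sum of adjacent inputs, y[t] = x[t] + x[t-1], with the element before the first treated as 0. A quick pure-Python check (the helper name `pair_sum` is ours, not part of the example) against the training and test data:

```python
def pair_sum(xs):
    """y[t] = x[t] + x[t-1], with x[-1] taken as 0."""
    return [xs[i] + (xs[i - 1] if i > 0 else 0) for i in range(len(xs))]

print(pair_sum([1, 2, 5, 6]))  # [1, 3, 7, 11] -- matches train_y[0]
print(pair_sum([1, 2, 3, 4]))  # [1, 3, 5, 7]  -- expected output for test_x[0]
print(pair_sum([4, 5, 6, 7]))  # [4, 9, 11, 13] -- expected output for test_x[1]
```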