7.TensorFlow的RNN和LSTM基础范例

来源:互联网 发布:数码宝贝网络侦探能力 编辑:程序博客网 时间:2024/06/05 01:08

需要在代码所在路径下新建一个save文件夹,待会模型会保存在这个文件夹中。

import os

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

# NOTE(review): tf.batch_matmul, tf.initialize_all_variables and the
# tensorflow.python.ops.rnn_cell import path look like the pre-1.0 TensorFlow
# API -- confirm the targeted TensorFlow version before upgrading.


class SeriesPredictor:
    """LSTM regressor that emits one scalar prediction per time step.

    The graph is built once in the constructor; ``train`` fits it and writes a
    checkpoint to ``save/model.ckpt``; ``test`` restores that checkpoint and
    runs inference.
    """

    def __init__(self, input_dim, seq_size, hidden_dim=10):
        """Create the placeholders, output-layer weights, loss and optimizer.

        :param input_dim: number of features per time step
        :param seq_size: number of time steps in every sequence
        :param hidden_dim: number of units in the LSTM cell
        """
        # Hyperparameters
        self.input_dim = input_dim
        self.seq_size = seq_size
        self.hidden_dim = hidden_dim

        # Weight variables and input placeholders
        self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]), name='W_out')
        self.b_out = tf.Variable(tf.random_normal([1]), name='b_out')
        self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim])
        self.y = tf.placeholder(tf.float32, [None, seq_size])

        # Cost optimizer: mean squared error minimized with Adam (default settings)
        self.cost = tf.reduce_mean(tf.square(self.model() - self.y))
        self.train_op = tf.train.AdamOptimizer().minimize(self.cost)

        # Auxiliary ops
        self.saver = tf.train.Saver()

    def model(self):
        """Build the prediction ops on top of ``self.x``.

        Runs an LSTM over ``self.x`` (fed as [batch, seq_size, input_dim]) and
        applies the shared output layer ``W_out`` / ``b_out`` to the LSTM
        output of every time step.

        :return: tensor of predictions, one scalar per time step, intended to
            have shape [batch, seq_size] so it lines up with ``self.y``
        """
        cell = rnn_cell.BasicLSTMCell(self.hidden_dim)
        outputs, states = rnn.dynamic_rnn(cell, self.x, dtype=tf.float32)
        num_examples = tf.shape(self.x)[0]
        # Give every example its own copy of W_out so the batched matmul can
        # multiply each example's per-step outputs by the same weights.
        W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1])
        out = tf.batch_matmul(outputs, W_repeated) + self.b_out
        # Drop only the trailing singleton axis. An unrestricted squeeze would
        # also remove the batch or time axis whenever either has length 1,
        # which breaks the element-wise comparison against self.y.
        out = tf.squeeze(out, [2])
        return out

    def train(self, train_x, train_y):
        """Fit the model for 1000 steps and save a checkpoint.

        Prints the step index and the mean squared error every 100 steps.

        :param train_x: input sequences, nested as [batch, seq_size, input_dim]
        :param train_y: target sequences, nested as [batch, seq_size]
        """
        # Saving fails if the checkpoint directory is missing, so create it.
        os.makedirs('save', exist_ok=True)
        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()
            sess.run(tf.initialize_all_variables())
            for i in range(1000):
                _, mse = sess.run([self.train_op, self.cost],
                                  feed_dict={self.x: train_x, self.y: train_y})
                if i % 100 == 0:
                    print(i, mse)
            save_path = self.saver.save(sess, 'save/model.ckpt')
            print('Model saved to {}'.format(save_path))

    def test(self, test_x):
        """Restore the saved checkpoint, predict for ``test_x`` and print it.

        :param test_x: input sequences, nested as [batch, seq_size, input_dim]
        :return: the predictions that were printed
        """
        with tf.Session() as sess:
            # model() is called a second time below; enabling reuse makes that
            # call pick up the existing LSTM variables instead of new ones.
            tf.get_variable_scope().reuse_variables()
            self.saver.restore(sess, 'save/model.ckpt')
            output = sess.run(self.model(), feed_dict={self.x: test_x})
            print(output)
            return output


if __name__ == '__main__':
    # Input dimension is 1; each sequence has 4 time steps.
    predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)
    train_x = [[[1], [2], [5], [6]],
               [[5], [7], [7], [8]],
               [[3], [4], [5], [7]]]
    train_y = [[1, 3, 7, 11],
               [5, 12, 14, 15],
               [3, 7, 9, 12]]
    predictor.train(train_x, train_y)

    test_x = [[[1], [2], [3], [4]],  # 1, 3, 5, 7
              [[4], [5], [6], [7]]]  # 4, 9, 11, 13
    predictor.test(test_x)

打印信息如下:

0 79.4345
100 49.6515
200 23.4903
300 11.8844
400 6.26166
500 4.21735
600 2.90793
700 2.06972
800 1.50332
900 1.09277
Model saved to save/model.ckpt
[[  0.68482286   2.59259105   4.72971964   6.3778677 ]
 [  4.54114914   9.24342632  11.5808382   12.61558914]]

模型要学习的目标是:每一步的输出等于当前输入与前一个输入之和(第一步只有当前输入)。因此两条测试序列的正确结果应该分别为 1, 3, 5, 7 和 4, 9, 11, 13。预测结果与之比较接近。

原创粉丝点击