TensorFlow RNN Tutorial and Code

Source: Internet | Editor: 程序博客网 | Time: 2024/06/05 06:13

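Analysis:
I have been studying TensorFlow for a while, so I plan to type out the tutorials on GitHub by hand and organize my thoughts along the way.
RNN part:
  1. Define the parameters, both data-related and training-related.
  2. Define the model, the loss function, and the optimizer.
  3. Train: prepare the data, feed it in, and print the results.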
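To make the data-related parameters in step 1 concrete: the code in this post reads each 28x28 MNIST image as a sequence of 28 rows (`n_steps`), where each row is a 28-pixel input vector (`n_input`). The NumPy sketch below mimics what `batch_x.reshape(...)` and `tf.unstack(x, n_steps, 1)` do to the shapes. It assumes random numbers as a stand-in for a real MNIST batch, so it only illustrates the layout, not the data.

```python
import numpy as np

batch_size, n_steps, n_input = 128, 28, 28

# Random stand-in (assumption) for one MNIST batch: next_batch returns
# images flattened to 784 = 28 * 28 values each.
batch_x = np.random.rand(batch_size, n_steps * n_input).astype(np.float32)

# Same reshape as in the training loop: [batch, 784] -> [batch, 28, 28]
batch_x = batch_x.reshape(batch_size, n_steps, n_input)

# NumPy analogue of tf.unstack(x, n_steps, 1): split along axis 1 into a
# list of n_steps arrays, each [batch, n_input] (one image row per step).
inputs = [batch_x[:, t, :] for t in range(n_steps)]

print(len(inputs), inputs[0].shape)  # 28 (128, 28)
```

Each element `inputs[t]` holds row `t` of every image in the batch, which is exactly what the LSTM consumes at time step `t`.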

Code:

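#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written for TensorFlow 1.x (tf.placeholder, tf.Session and tf.contrib
# were removed in TensorFlow 2.x).
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn

# Download (if needed) and load MNIST with one-hot labels
mnist = input_data.read_data_sets("./data", one_hot=True)

# Training parameters
learning_rate = 0.001
training_iters = 100000   # total number of training examples to feed
batch_size = 128
display_step = 10         # print progress every 10 steps

# Network parameters: each 28x28 image is read as 28 time steps of 28 pixels
n_input = 28    # pixels per row (input size per time step)
n_steps = 28    # rows per image (number of time steps)
n_hidden = 128  # LSTM hidden units
n_classes = 10  # digits 0-9

x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Output layer: maps the last LSTM output to class logits
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}


def RNN(x, weights, biases):
    # [batch, n_steps, n_input] -> list of n_steps tensors of [batch, n_input]
    x = tf.unstack(x, n_steps, 1)
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    # Classify from the output at the last time step
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


pred = RNN(x, weights, biases)

# Loss and optimizer
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluation
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step = 1
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Flattened 784-pixel images -> [batch, 28 rows, 28 pixels]
        batch_x = batch_x.reshape(batch_size, n_steps, n_input)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Loss and accuracy on the current training minibatch
            loss, acc = sess.run([cost, accuracy],
                                 feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) +
                  ", Minibatch Loss= " + "{:.6f}".format(loss) +
                  ", Training Accuracy= " + "{:.5f}".format(acc))
        step += 1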
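To show roughly what `BasicLSTMCell` and `static_rnn` compute inside `RNN()`, here is a minimal NumPy sketch of the standard LSTM recurrence with `forget_bias` added to the forget gate, unrolled over the 28 time steps, followed by the output layer, a softmax cross-entropy like `cost`, and the argmax accuracy like `accuracy`. It is an illustration under assumptions: the weights are random rather than trained, the helper names `lstm_step`, `rnn_logits` and `softmax_cross_entropy` are hypothetical, and the gate layout inside `kernel` is not claimed to match TensorFlow's variable layout.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h, c, kernel, bias, forget_bias=1.0):
    """Hypothetical helper: one LSTM update in the style of BasicLSTMCell.

    x_t: [batch, n_input]; h, c: [batch, n_hidden]
    kernel: [n_input + n_hidden, 4 * n_hidden]; bias: [4 * n_hidden]
    """
    gates = np.concatenate([x_t, h], axis=1) @ kernel + bias
    # Input gate i, candidate j, forget gate f, output gate o
    # (this split order is an assumption of the sketch).
    i, j, f, o = np.split(gates, 4, axis=1)
    # forget_bias is added before the sigmoid, so an untrained cell
    # leans toward keeping its previous state.
    c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h_new = np.tanh(c_new) * sigmoid(o)
    return h_new, c_new


def rnn_logits(inputs, kernel, bias, w_out, b_out):
    """Hypothetical helper: unrolled loop over time steps (like static_rnn),
    then a linear layer on the last output only, as in RNN()."""
    batch, n_hidden = inputs[0].shape[0], w_out.shape[0]
    h = np.zeros((batch, n_hidden))
    c = np.zeros((batch, n_hidden))
    for x_t in inputs:
        h, c = lstm_step(x_t, h, c, kernel, bias)
    return h @ w_out + b_out


def softmax_cross_entropy(logits, labels):
    """Hypothetical helper: mean softmax cross-entropy over the batch."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1).mean()


rng = np.random.default_rng(0)
batch, n_steps, n_input, n_hidden, n_classes = 4, 28, 28, 128, 10
inputs = [rng.random((batch, n_input)) for _ in range(n_steps)]
kernel = rng.normal(scale=0.1, size=(n_input + n_hidden, 4 * n_hidden))
bias = np.zeros(4 * n_hidden)
w_out = rng.normal(size=(n_hidden, n_classes))
b_out = rng.normal(size=n_classes)

logits = rnn_logits(inputs, kernel, bias, w_out, b_out)
labels = np.eye(n_classes)[rng.integers(0, n_classes, size=batch)]  # one-hot
loss = softmax_cross_entropy(logits, labels)
acc = (logits.argmax(1) == labels.argmax(1)).mean()

print(logits.shape)  # (4, 10)
print(loss, acc)
```

Only the final hidden state feeds the classifier, which mirrors `outputs[-1]` in `RNN()`: the earlier rows influence the prediction only through the recurrent state.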


Output:

/anaconda/bin/python2.7 /Users/xxxx/PycharmProjects/TF_3/tf_rnn.py
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
2017-07-15 16:41:15.125981: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn’t compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-07-15 16:41:15.125994: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn’t compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-07-15 16:41:15.125997: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn’t compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-07-15 16:41:15.126002: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn’t compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
Iter 1280, Minibatch Loss= 1.842738, Training Accuracy= 0.33594
Iter 2560, Minibatch Loss= 1.489123, Training Accuracy= 0.50000
Iter 3840, Minibatch Loss= 1.300060, Training Accuracy= 0.57812
Iter 5120, Minibatch Loss= 1.244872, Training Accuracy= 0.62500
Iter 6400, Minibatch Loss= 0.947143, Training Accuracy= 0.71094
Iter 7680, Minibatch Loss= 0.709695, Training Accuracy= 0.75781
Iter 8960, Minibatch Loss= 0.799844, Training Accuracy= 0.76562
Iter 10240, Minibatch Loss= 0.594611, Training Accuracy= 0.83594
Iter 11520, Minibatch Loss= 0.529350, Training Accuracy= 0.82031
Iter 12800, Minibatch Loss= 0.624426, Training Accuracy= 0.82031
Iter 14080, Minibatch Loss= 0.481889, Training Accuracy= 0.82812
Iter 15360, Minibatch Loss= 0.449692, Training Accuracy= 0.84375
Iter 16640, Minibatch Loss= 0.418820, Training Accuracy= 0.85938
Iter 17920, Minibatch Loss= 0.412161, Training Accuracy= 0.85156
Iter 19200, Minibatch Loss= 0.256099, Training Accuracy= 0.90625
Iter 20480, Minibatch Loss= 0.227309, Training Accuracy= 0.90625
Iter 21760, Minibatch Loss= 0.431014, Training Accuracy= 0.85938
Iter 23040, Minibatch Loss= 0.377097, Training Accuracy= 0.87500
Iter 24320, Minibatch Loss= 0.268153, Training Accuracy= 0.89844
Iter 25600, Minibatch Loss= 0.170557, Training Accuracy= 0.95312
Iter 26880, Minibatch Loss= 0.286947, Training Accuracy= 0.91406
Iter 28160, Minibatch Loss= 0.189623, Training Accuracy= 0.94531
Iter 29440, Minibatch Loss= 0.228949, Training Accuracy= 0.95312
Iter 30720, Minibatch Loss= 0.157198, Training Accuracy= 0.94531
Iter 32000, Minibatch Loss= 0.205744, Training Accuracy= 0.93750
Iter 33280, Minibatch Loss= 0.195218, Training Accuracy= 0.92188
Iter 34560, Minibatch Loss= 0.177956, Training Accuracy= 0.92969
Iter 35840, Minibatch Loss= 0.131563, Training Accuracy= 0.96875
Iter 37120, Minibatch Loss= 0.215156, Training Accuracy= 0.92969
Iter 38400, Minibatch Loss= 0.232274, Training Accuracy= 0.94531
Iter 39680, Minibatch Loss= 0.324053, Training Accuracy= 0.91406
Iter 40960, Minibatch Loss= 0.196385, Training Accuracy= 0.93750
Iter 42240, Minibatch Loss= 0.151221, Training Accuracy= 0.95312
Iter 43520, Minibatch Loss= 0.242021, Training Accuracy= 0.95312
Iter 44800, Minibatch Loss= 0.304008, Training Accuracy= 0.90625
Iter 46080, Minibatch Loss= 0.185177, Training Accuracy= 0.93750
Iter 47360, Minibatch Loss= 0.190960, Training Accuracy= 0.94531
Iter 48640, Minibatch Loss= 0.141995, Training Accuracy= 0.94531
Iter 49920, Minibatch Loss= 0.199995, Training Accuracy= 0.94531
Iter 51200, Minibatch Loss= 0.193773, Training Accuracy= 0.92188
Iter 52480, Minibatch Loss= 0.151757, Training Accuracy= 0.94531
Iter 53760, Minibatch Loss= 0.153755, Training Accuracy= 0.94531
Iter 55040, Minibatch Loss= 0.141472, Training Accuracy= 0.93750
Iter 56320, Minibatch Loss= 0.168057, Training Accuracy= 0.96094
Iter 57600, Minibatch Loss= 0.135691, Training Accuracy= 0.96094
Iter 58880, Minibatch Loss= 0.097003, Training Accuracy= 0.97656
Iter 60160, Minibatch Loss= 0.274090, Training Accuracy= 0.92188
Iter 61440, Minibatch Loss= 0.147230, Training Accuracy= 0.95312
Iter 62720, Minibatch Loss= 0.106019, Training Accuracy= 0.96094
Iter 64000, Minibatch Loss= 0.101133, Training Accuracy= 0.97656
Iter 65280, Minibatch Loss= 0.169548, Training Accuracy= 0.93750
Iter 66560, Minibatch Loss= 0.101966, Training Accuracy= 0.96094
Iter 67840, Minibatch Loss= 0.106501, Training Accuracy= 0.96875
Iter 69120, Minibatch Loss= 0.082817, Training Accuracy= 0.96875
Iter 70400, Minibatch Loss= 0.192926, Training Accuracy= 0.96094
Iter 71680, Minibatch Loss= 0.086935, Training Accuracy= 0.96875
Iter 72960, Minibatch Loss= 0.052052, Training Accuracy= 0.98438
Iter 74240, Minibatch Loss= 0.129968, Training Accuracy= 0.95312
Iter 75520, Minibatch Loss= 0.058070, Training Accuracy= 0.99219
Iter 76800, Minibatch Loss= 0.089518, Training Accuracy= 0.96875
Iter 78080, Minibatch Loss= 0.106092, Training Accuracy= 0.98438
Iter 79360, Minibatch Loss= 0.223101, Training Accuracy= 0.92188
Iter 80640, Minibatch Loss= 0.069419, Training Accuracy= 0.97656
Iter 81920, Minibatch Loss= 0.050585, Training Accuracy= 0.99219
Iter 83200, Minibatch Loss= 0.048002, Training Accuracy= 0.98438
Iter 84480, Minibatch Loss= 0.094293, Training Accuracy= 0.96875
Iter 85760, Minibatch Loss= 0.152253, Training Accuracy= 0.96094
Iter 87040, Minibatch Loss= 0.085382, Training Accuracy= 0.97656
Iter 88320, Minibatch Loss= 0.147018, Training Accuracy= 0.95312
Iter 89600, Minibatch Loss= 0.099780, Training Accuracy= 0.96094
Iter 90880, Minibatch Loss= 0.118362, Training Accuracy= 0.93750
Iter 92160, Minibatch Loss= 0.110498, Training Accuracy= 0.96094
Iter 93440, Minibatch Loss= 0.077664, Training Accuracy= 0.98438
Iter 94720, Minibatch Loss= 0.070865, Training Accuracy= 0.96094
Iter 96000, Minibatch Loss= 0.156309, Training Accuracy= 0.94531
Iter 97280, Minibatch Loss= 0.116825, Training Accuracy= 0.94531
Iter 98560, Minibatch Loss= 0.099852, Training Accuracy= 0.96875
Iter 99840, Minibatch Loss= 0.116358, Training Accuracy= 0.96875

Process finished with exit code 0
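The `Iter` values in this log follow directly from the loop condition: `step` starts at 1 and the loop runs while `step * batch_size < training_iters`, printing every `display_step` steps with `Iter` equal to `step * batch_size`. The plain-Python sketch below reproduces only that schedule (no TensorFlow), reusing the parameter values from the code.

```python
batch_size = 128
training_iters = 100000
display_step = 10

printed = []
step = 1
while step * batch_size < training_iters:
    if step % display_step == 0:
        printed.append(step * batch_size)  # the "Iter" value in the log
    step += 1

print(printed[0], printed[-1], len(printed))  # 1280 99840 78
```

So the run performs 781 optimizer steps (781 * 128 = 99968 < 100000) and prints 78 progress lines, matching the log. Note that each reported loss and accuracy is measured on the same training minibatch that was just used for the update, not on `mnist.test`, so it is a noisy and optimistic estimate.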


Original article: http://www.tensorflownews.com/2017/07/15/tensorflow-rnn-turorial-mnist-code/