TensorFlow学习笔记(四):手写数字识别之LSTM网络

来源:互联网 发布:java多线程挂起和阻塞 编辑:程序博客网 时间:2024/05/29 15:39

长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。

在 TensorFlow 中,基础的 LSTM 单元格声明为:tf.contrib.rnn.BasicLSTMCell num_units ),其中 num_units 指一个 LSTM 单元格中的单元数,相当于前馈神经网络中的隐藏层神经元个数,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。

在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义为:tf.static_rnn cellinputs ),其中 inputs 接受 shape为 [batch_size,input_size] 的张量列表,列表中每一个元素都分别对应网络展开的时间步。

比如 28 x 28 的图像,将网络按 28 个时间步展开,则在每一个时间步中,可以输入一行 28(input_size) 个像素,经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。

具体实现

import tensorflow as tf

  1. from tensorflow.contrib import rnn

  2. from tensorflow.examples.tutorials.mnist  import input_data

  3. mnist=input_data.read_data_sets("/tmp/data/"one_hot=True)

  4. time_steps=28

  5. num_units=128

  6. n_input=28

  7. learning_rate=0.001

  8. n_classes=10

  9. batch_size=128

tf.placeholder "float", [Nonetime_stepsn_input] )

    1. tf.placeholder  "float",  [Nonen_classes] )

  1. tf.Variable tf.random_normal ( [num_units,  n_classes] ) )

  2. tf.Variable tf.random_normal ( [n_classes] ) )

#将 shape 为 [batch_size, time_steps, n_input] 的输入转换成,长度为 time_steps  的 shape 为[batch_size, n_inputs] 的张量列表,再输入到 static_rnn。

  1. input tf.unstack time_steps)

#定义LSTM网络

  1. lstm_layer rnn.BasicLSTMCell num_unitsforget_bias=)

  2. outputsrnn.static_rnn lstm_layerinputdtype="float32" )

  1. prediction tf.matmul outputs[-1], w) + b

loss tf.reduce_mean tf.nn.softmax_cross_entropy_with_logits logits=predictionlabels=) )

  1. op tf.train.AdamOptimizer learning_rate=learning_rate ).minimize loss )

  2. correct_prediction tf.equal tf.argmax prediction), tf.argmax y) )

  3. accuracy tf.reduce_mean tf.cast correct_predictiontf.float32 ) )

  1. #running

  2. init tf.global_variables_initializer( )

  3. withtf.Session() as sess:

  4.     sess.run init 

  5.     iter=1

  6.     while iter<800:

  7.         batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

  8.         batch_x=batch_x.reshape((batch_size,time_steps,n_input))

  9.         sess.run(opt,feed_dict={x:batch_x,y:batch_y})

  10.         iter=iter+1

        #calculating test accuracy

  1.         test_data mnist.test.images[:128].reshape ( (-1time_stepsn_input ) )

  2.         test_label mnist.test.labels[:128]

  3.         print "Testing Accuracy:",sess.run accuracy,feed_dict={x:test_data,y:test_label} ) )

最终准确率为:

Testing Accuracy: 99.21%。

阅读全文
0 0