TensorFlow学习笔记(四):手写数字识别之LSTM网络
来源:互联网 发布:java多线程挂起和阻塞 编辑:程序博客网 时间:2024/05/29 15:39
长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。
在 TensorFlow 中,基础的 LSTM 单元格声明为:tf.contrib.rnn.BasicLSTMCell ( num_units ),其中 num_units 指一个 LSTM 单元格中的单元数,相当于前馈神经网络中的隐藏层神经元个数,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。
在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义为:tf.static_rnn ( cell, inputs ),其中 inputs 接受 shape为 [batch_size,input_size] 的张量列表,列表中每一个元素都分别对应网络展开的时间步。
比如 28 x 28 的图像,将网络按 28 个时间步展开,则在每一个时间步中,可以输入一行 28(input_size) 个像素,经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。
具体实现
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("/tmp/data/", one_hot=True)
time_steps=28
num_units=128
n_input=28
learning_rate=0.001
n_classes=10
batch_size=128
x = tf.placeholder ( "float", [None, time_steps, n_input] )
y = tf.placeholder ( "float", [None, n_classes] )
w = tf.Variable ( tf.random_normal ( [num_units, n_classes] ) )
b = tf.Variable ( tf.random_normal ( [n_classes] ) )
#将 shape 为 [batch_size, time_steps, n_input] 的输入转换成,长度为 time_steps 的 shape 为[batch_size, n_inputs] 的张量列表,再输入到 static_rnn。
input = tf.unstack ( x , time_steps, 1 )
#定义LSTM网络
lstm_layer = rnn.BasicLSTMCell ( num_units, forget_bias=1 )
outputs, _ = rnn.static_rnn ( lstm_layer, input, dtype="float32" )
prediction = tf.matmul ( outputs[-1], w) + b
loss = tf.reduce_mean ( tf.nn.softmax_cross_entropy_with_logits ( logits=prediction, labels=y ) )
op = tf.train.AdamOptimizer ( learning_rate=learning_rate ).minimize ( loss )
correct_prediction = tf.equal ( tf.argmax ( prediction, 1 ), tf.argmax ( y, 1 ) )
accuracy = tf.reduce_mean ( tf.cast ( correct_prediction, tf.float32 ) )
#running
init = tf.global_variables_initializer( )
withtf.Session() as sess:
sess.run ( init )
iter=1
while iter<800:
batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)
batch_x=batch_x.reshape((batch_size,time_steps,n_input))
sess.run(opt,feed_dict={x:batch_x,y:batch_y})
iter=iter+1
#calculating test accuracy
test_data = mnist.test.images[:128].reshape ( (-1, time_steps, n_input ) )
test_label = mnist.test.labels[:128]
print ( "Testing Accuracy:",sess.run ( accuracy,feed_dict={x:test_data,y:test_label} ) )
最终准确率为:
Testing Accuracy: 99.21%。
- TensorFlow学习笔记(四):手写数字识别之LSTM网络
- TensorFlow学习笔记(一):手写数字识别之softmax回归
- tensorflow进行MNIST手写数字识别-LSTM
- 深度学习四:tensorflow-使用卷积神经网络识别手写数字
- TensorFlow学习笔记(3)----CNN识别MNIST手写数字
- TensorFlow学习笔记(二)MNIST手写数字识别
- TensorFlow学习笔记之源码分析(2)----手写数字识别mnist example
- tensorflow笔记(四)之MNIST手写识别系列一
- tensorflow笔记(四)之MNIST手写识别系列一
- tensorflow 学习笔记12 循环神经网络RNN LSTM结构实现MNIST手写识别
- TensorFlow笔记之一:MNIST手写数字识别
- TensorFlow学习---实现mnist手写数字识别
- tensorflow识别手写数字
- Tensorflow手写数字识别
- tensorflow实战之四:MNIST手写数字识别的优化3-过拟合
- tensorflow-mnist手写数字识别
- TensorFlow实现识别手写数字
- TensorFlow学习笔记(3)--实现Softmax逻辑回归识别手写数字(MNIST数据集)
- HDU 5936 Difference(思维+二分)——2016年中国大学生程序设计竞赛(杭州)
- STM32(二)之GPIO操作(1)——之输入输出操作
- oracle12C--EXECUTE IMMEDIATE语句(61)
- 一个小时写一个简单的iOS新闻应用
- Web31 懒加载
- TensorFlow学习笔记(四):手写数字识别之LSTM网络
- Java程序员秋招面经大合集(BAT美团网易小米华为中兴等)
- Ubuntu 14.04 安装搜狗拼音
- C#无边框控制窗体移动
- redis-set扩展命令
- Zongjie
- 解决Ubuntu 14.04 built-in display 分辨率较低的方法
- hdu1102之prim(堆优化)解法
- Keras入门(3)——磨刀不误砍柴工