tensoflow练习2:利用Recurrent Neural Network 进行分类
来源:互联网 发布:检察院 大数据 公司 编辑:程序博客网 时间:2024/06/03 18:51
快一周没写博客了,总觉得缺少点什么。最近忙着写论文、改论文也没什么时间写。好了,不废话了,直接上干货,利用RNN进行分类。RNN主要用于序列数据的处理,在图像、语音、文本等领域有着广泛的应用。这里使用RNN对手写体进行识别(0-9共10类)。
代码如下:
#coding=utf-8import tensorflow as tfimport numpy as numpyfrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data/',one_hot=True)#下载数据,有点慢#一张图片是28*28,FNN一次将数据输入到网络,RNN将它分成快chunk_size = 28chunk_n = 28rnn_size = 256#隐层尺寸(维度)#输出层n_output_layer = 10 #输出层X = tf.placeholder('float',[None,chunk_n,chunk_size])Y = tf.placeholder('float')def recurrent_neural_network(data): layer = {'w_':tf.Variable(tf.random_normal([rnn_size,n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))} lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size) data = tf.transpose(data,[1,0,2]) data = tf.reshape(data,[-1,chunk_size]) data = tf.split(0,chunk_n,data) outputs, status = tf.nn.rnn(lstm_cell,data,dtype=tf.float32) output = tf.add(tf.matmul(outputs[-1],layer['w_']),layer['b_']) return output#每一次100条数据batch_size = 100def train_neural_network(X,Y): predict = recurrent_neural_network(X) cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(predict,Y)) optimizer = tf.train.AdamOptimizer().minimize(cost_func) epochs = 13 with tf.Session() as session: session.run(tf.global_variables_initializer()) epoch_loss =0 print('training begins:') for epoch in range(epochs): for i in range( int(mnist.train.num_examples/batch_size)): x,y = mnist.train.next_batch(batch_size) x = x.reshape([batch_size,chunk_n,chunk_size]) _,c =session.run([optimizer,cost_func],feed_dict={X:x,Y:y}) epoch_loss += c print(epoch,':',epoch_loss) correct = tf.equal(tf.argmax(predict,1),tf.argmax(Y,1)) accuracy = tf.reduce_mean(tf.cast(correct,'float')) print('精确率:',accuracy.eval({X:mnist.test.images.reshape(-1,chunk_n,chunk_size),Y:mnist.test.labels}))train_neural_network(X,Y)
接下来,我对上述代码进行分析;
数据准备,就是最上面的下载代码。
(1)定义数据与标签类型;
X = tf.placeholder('float',[None,chunk_n,chunk_size])Y = tf.placeholder('float')n_output_layer = 10 #输出类数
其中,None表示不确定,可随机自定。
(2)定义rnn:
def recurrent_neural_network(data): layer = {'w_':tf.Variable(tf.random_normal([rnn_size,n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))} lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)#定义神经元LSTM data = tf.transpose(data,[1,0,2]) data = tf.reshape(data,[-1,chunk_size]) data = tf.split(0,chunk_n,data) outputs, status = tf.nn.rnn(lstm_cell,data,dtype=tf.float32) output = tf.add(tf.matmul(outputs[-1],layer['w_']),layer['b_']) return output
layer:是对输出进行映射的参数。
比较难理解的就是三行data了。其实不用纠结。这三行的目的是将3维张量转换成chunk_n个形状为batch_size * chunk_size的张量。
outpus[-1]表示最后的隐层表示。
(3)定义网络训练
def train_neural_network(X,Y): predict = recurrent_neural_network(X)#预测 cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(predict,Y))#损失函数 optimizer = tf.train.AdamOptimizer().minimize(cost_func)#优化器 epochs = 13 with tf.Session() as session: session.run(tf.global_variables_initializer()) epoch_loss =0 print('training begins:') for epoch in range(epochs): for i in range( int(mnist.train.num_examples/batch_size)):#一共多少个batch x,y = mnist.train.next_batch(batch_size) x = x.reshape([batch_size,chunk_n,chunk_size])#重新塑形 _,c =session.run([optimizer,cost_func],feed_dict={X:x,Y:y}) epoch_loss += c print(epoch,':',epoch_loss) correct = tf.equal(tf.argmax(predict,1),tf.argmax(Y,1)) accuracy = tf.reduce_mean(tf.cast(correct,'float')) print('精确率:',accuracy.eval({X:mnist.test.images.reshape(-1,chunk_n,chunk_size),Y:mnist.test.labels}))
X,Y相当于盒子,容器,不要纠结太多。loss为累加损失
(4)输出结果:
阅读全文
0 0
- tensoflow练习2:利用Recurrent Neural Network 进行分类
- Recurrent Neural Network系列2--利用Python,Theano实现RNN
- Recurrent Neural Network (RNN)
- lecture10,Recurrent Neural Network
- CS231N-10-Recurrent Neural Network
- Recurrent Neural Network系列4--利用Python,Theano实现GRU或LSTM
- tensoflow练习3:卷积神经网络用于分类
- Recurrent Neural Network 学习之路
- 回归神经网络RNN(Recurrent Neural network)
- Tensorflow: recurrent neural network (mnist basic)
- Tensorflow: recurrent neural network char-level 0
- Tensorflow: recurrent neural network char-level 1
- 机器学习: Python with Recurrent Neural Network
- 【论文笔记】Recurrent Neural Network Regularization
- 详解循环神经网络(Recurrent Neural Network)
- 详解循环神经网络(Recurrent Neural Network)
- 循环神经网络(Recurrent Neural Network)
- Recurrent Neural Network Language Modeling Toolkit代码学习
- Python常用内建模块—datetime\collections\struct
- 初步认识Spring
- GCD里面的关键字理解
- linux进程利用文件通信
- QWebEngine与JS交互
- tensoflow练习2:利用Recurrent Neural Network 进行分类
- iOS 自定义Cell拖拽的另一种形式
- N皇后问题<经典DFS>
- ORACLE 常用函数——日期/时间函数
- Spring+Hibernate双数据源测试Mysql集群读写分离
- SQL连接的几种方式
- 设计模式-18-备忘录模式
- 虚拟机里部署java web工程
- HDU 2008