使用RNN进行图像分类
来源:互联网 发布:淘宝客鹊桥是什么意思 编辑:程序博客网 时间:2024/05/23 11:23
使用CNN进行图像分类是很稀疏平常的,其实使用RNN也是可以的.
这篇介绍的就是使用RNN(LSTM/GRU)进行mnist的分类,对RNN不太了解的可以看看下面的材料:
1. [LSTM的介绍] http://colah.github.io/posts/2015-08-Understanding-LSTMs/
2. [The Unreasonable Effectiveness of RNNs] http://karpathy.github.io/2015/05/21/rnn-effectiveness/
3. [WildML RNN介绍] http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/
4. [RNN in Tensorflow] http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/
基础介绍
如何使用RNN进行mnist的分类呢?其实对应到RNN里面就是个Sequence Classification
问题.
先看下CS231n
中关于RNN部分的一张图:
其实图像的分类对应上图就是个many to one
的问题. 对于mnist来说其图像的size是28*28,如果将其看成28个step,每个step的size是28的话,是不是刚好符合上图. 当我们得到最终的输出的时候将其做一次线性变换就可以加softmax来分类了,其实挺简单的.
具体实现
tf中RNN有很多的变体,最出名也是最常用的就是: LSTM
和GRU
,其它的还有向GridLSTM
、AttentionCell
等,要查看最新tf支持的RNN类型,基本只要关注这两个文件就可以了:
1. [rnn_cell.py] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py
2. [contrib/rnn_cell.py] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.py
对于常见的RNN cell的使用总结:
获取数据
很简单,tf自带都帮我们写好了,直接调用就行了.
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist_data = input_data.read_data_sets('data/mnist', one_hot=True)
如何不存在data/mnist
这个目录,其会自己下载mnist数据,要是你的网络不行也可以自己去mnist的网站下载然后将数据放在目录下就可以了.
tf贴心到什么程度呢?连batch generator都帮我们写好了,直接用next_batch
就可以获得下一个batch的数据.
train_x, train_y = mnist_data.train.images, mnist_data.train.labelstest_x, test_y = mnist_data.test.images, mnist_data.test.labelsbatch_x, batch_y = mnist.train.next_batch(batch_size)
training examples是55000, test examples是10000,validation examples是5000.
定义网络
我们使用3层的GRU
,hidden units
是200的带dropout
的RNN来作为mnist分类的网络,具体代码如下:
cells = list()for _ in range(num_layers): cell = tf.nn.rnn_cell.GRUCell(num_units=num_hidden) cell = tf.nn.rnn_cell.DropoutWrapper(cell=cell, output_keep_prob=1.0-dropout) cells.append(cell)network = tf.nn.rnn_cell.MultiRNNCell(cells=cells)outputs, last_state = tf.nn.dynamic_rnn(cell=network, inputs=data, dtype=tf.float32)# get last outputoutputs = tf.transpose(outputs, (1, 0, 2))last_output = tf.gather(outputs, int(outputs.get_shape()[0])-1)# linear transformout_size = int(target.get_shape()[1])weight, bias = initialize_weight_bias(in_size=num_hidden, out_size=out_size)logits = tf.add(tf.matmul(last_output, weight), bias)return logits
因为mnist太简单,这个简单的网络其实已经可以搞定mnist的分类问题,后期的test acc可以到0.985(within 3 epoches).
训练和测试
分类嘛,还是使用cross entropy
作为loss,然后计算下错误率是多少,代码如下:
batch_size = 64, lr = 0.001
# placeholdersinput_x = tf.placeholder(tf.float32, shape=(None, 28, 28))input_y = tf.placeholder(tf.float32, shape=(None, 10))dropout = tf.placeholder(tf.float32)input_logits = model(input_x, input_y, dropout)# loss and error rate oploss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=input_logits, labels=input_y))train_op = tf.train.RMSPropOptimizer(0.001).minimize(loss)input_prob = tf.nn.softmax(input_logits)error_count = tf.not_equal(tf.arg_max(input_prob, 1), tf.arg_max(input_y, 1))error_rate_op = tf.reduce_mean(tf.cast(error_count, tf.float32))
input_x
和input_y
表示输入的image和label,model
就是上面定义的3层GRU模型;可以使用tf.summary来使用tensorboard
查看训练时的error rate
和loss
等信息.
训练代码:
for step in range(total_steps): train_x, train_y = mnist_data.train.next_batch(default_batch_size) train_x = train_x.reshape(-1, 28, 28) feed_dict = {input_x: train_x, input_y: train_y, dropout: default_dropout} _, summary = session.run([train_op, merge_summary_op], feed_dict=feed_dict) # write logs summary_writer.add_summary(summary, global_step=epoch*total_steps+step)
测试代码:
# testif step > 0 and (step % test_freq == 0): avg_error = 0 for test_step in range(total_test_steps): test_x, test_y = mnist_data.test.next_batch(default_batch_size) test_x = test_x.reshape(-1, 28, 28) feed_dict = {input_x: test_x, input_y: test_y, dropout: 0} test_error = session.run(error_rate_op, feed_dict=feed_dict) avg_error += test_error / total_test_steps print('epoch: %d, steps: %d, avg_test_error: %.4f' % (epoch, step, avg_error))
结果
训练时的loss和error_rate:
测试的error_rate:
我只跑了3个epoch,错误率基本降低到1.5%左右,亦即正确率在98.5%左右,多跑几个epoch可能错误率还能继续降低,不过对于我们这个demo来说已经够了.
代码我上传在 http://download.csdn.net/download/gavin__zhou/10154583,有需要的可以下载.
- 使用RNN进行图像分类
- [深度学习框架] Keras上使用RNN进行mnist分类
- 使用Keras进行图像分类
- 使用R语言进行图像分类
- 使用TensorFlow-Slim进行图像分类
- Linux下使用caffe进行图像分类
- cnn、rnn相结合进行文本分类
- 使用Keras面向小数据集进行图像分类
- 使用词袋模型对图像进行分类
- 使用Keras预训练模型ResNet50进行图像分类
- 使用预训练模型对图像进行分类
- HOG+SVM进行图像分类
- 使用libsvm进行分类
- 使用决策树进行分类
- RNN(LSTM)用于分类
- Tensorflow-rnn(mnist分类)
- K-means对图像进行分类
- 利用CNN进行图像分类学习笔记
- 购物车管理模块
- 8.5调用函数与数组取负值结束
- recyclerview实现多条目
- [atcoder] agc86 D
- java多线程3-线程的同步与数据传递
- 使用RNN进行图像分类
- PPT 之神器 SmartArt
- android从放弃到精通 第七天 tomorrow
- Uinty学习概述
- 【TensorFlow】神经网络参数与变量(四)
- 笔记
- 【赠书】拨云见日
- 【演讲实录】银行PB级别海量非结构化数据管理实践
- 【社招持续篇】云和恩墨虚位以待,你来不来!