用RNN做MNIST分类
来源:互联网 发布:js iframe 跳转 编辑:程序博客网 时间:2024/05/11 16:33
1.前言
RNN常用作NLP中,像图片生成文字、自动生成古诗词等。这篇文章用RNN做MNIST手写数字识别,分类效果虽然没有CNN效果好,但准确率也能够达到96%。
2.环境
Mac os系统,python:3.5,Keras
3.代码实现
import numpy as npnp.random.seed(1337) from keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import SimpleRNN, Activation, Densefrom keras.optimizers import AdamTIME_STEPS = 28 INPUT_SIZE = 28 BATCH_SIZE = 50BATCH_INDEX = 0OUTPUT_SIZE = 10CELL_SIZE = 50LR = 0.001(X_train, y_train), (X_test, y_test) = mnist.load_data()# data pre-processingX_train = X_train.reshape(-1, 28, 28) / 255. # normalizeX_test = X_test.reshape(-1, 28, 28) / 255. # normalizey_train = np_utils.to_categorical(y_train, num_classes=10)y_test = np_utils.to_categorical(y_test, num_classes=10)# build RNN modelmodel = Sequential()# RNN cellmodel.add(SimpleRNN( batch_input_shape=(None, TIME_STEPS, INPUT_SIZE), output_dim=CELL_SIZE, unroll=True,))# output layermodel.add(Dense(OUTPUT_SIZE))model.add(Activation('softmax'))# optimizeradam = Adam(LR)model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])# trainingfor step in range(40001): X_batch = X_train[BATCH_INDEX: BATCH_INDEX+BATCH_SIZE, :, :] Y_batch = y_train[BATCH_INDEX: BATCH_INDEX+BATCH_SIZE, :] cost = model.train_on_batch(X_batch, Y_batch) BATCH_INDEX += BATCH_SIZE BATCH_INDEX = 0 if BATCH_INDEX >= X_train.shape[0] else BATCH_INDEX if step % 500 == 0: cost, accuracy = model.evaluate(X_test, y_test, batch_size=y_test.shape[0], verbose=False) print('test cost: ', cost, 'test accuracy: ', accuracy)
4.结果
阅读全文
2 0
- 用RNN做MNIST分类
- Tensorflow-rnn(mnist分类)
- tensorflow利用RNN和双向RNN实现MNIST分类问题
- Tensorflow学习: RNN-LSTM应用于MNIST数据分类
- [深度学习框架] Keras上使用RNN进行mnist分类
- RNN实践一:LSTM实现MNIST数字分类
- 单向RNN和双向RNN在mnist数据集上的分类实验
- TensorFlow MNIST RNN LSTM
- 用rnn做文本生成
- 利用RNN做脑电信号的分类(一)
- MNIST分类
- 使用keras对mnist数据集做分类
- RNN(LSTM)用于分类
- tensorflow构建RNN识别mnist手写数字
- tf8.mnist分类学习
- TensorFlow实现 mnist分类
- RNN用于二值分类
- 使用RNN进行图像分类
- FTPrep, 3 Longest Substring Without Repeating Characters
- iMindMap各种视图介绍
- Django-信号Signals
- CentOS 7配置静态IP步骤
- 第四季1.Java中的泛型
- 用RNN做MNIST分类
- POJ 1986 Distance Queries
- Ubuntu挂载新硬盘
- \r与\n的区别,以及\r\n的用法
- Linux Terminal(终端快捷键)
- 两年成长,甘苦自知
- 【笔记】 WebService之自定义拦截器
- RecyclerView完全解析,让你从此爱上RecyclerView
- Vim技能修炼教程(4)