tensorflow-mnist手写数字识别
来源:互联网 发布:臭氧灯能除螨虫知乎 编辑:程序博客网 时间:2024/04/29 19:49
mnist数据集
数据集简介
mnist手写数字数据集分为训练集和测试集,其中训练集有60000张图片,测试集有10000张图片。每张图片都是灰度图像,像素取值范围从0~255,图片大小为28×28,并且每张图片都对应0~9中的一个数字。
更多关于mnist手写数字数据集的介绍,点这
图片例子如下:
图像质量问题
数据集中绝大部分图像都可以很轻松地辨别出是哪个数字,但仍存在一小部分图像很难分辨出数字。
图片例子如下:
显示图片
'''tensorflow version: 1.0.0'''from tensorflow.examples.tutorials.mnist import input_data# 以下是导入的input_data.py的代码from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrangeimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets# read_data_sets是一个函数,专门用来读取mnist数据集。# 其中这个函数将训练集分为两部分,55000张图片作为训练集,5000张图片作为验证集。
'''tensorflow version: 1.0.0'''# 将以下代码保存为py文件,修改数据集路径后可直接运行from tensorflow.examples.tutorials.mnist import input_dataimport matplotlib.pyplot as plt# data_dir是数据集的路径,注意四个文件应当在同一个文件夹中data_sets = input_data.read_data_sets('data_dir')# 分别输出训练集,验证集,测试集的图片数量print('number of images in training set: %d' % data_sets.train.num_examples)print('number of images in validation set: %d' % data_sets.validation.num_examples)print('number of images in testing set: %d' % data_sets.test.num_examples)# 显示图片image, label = data_sets.train.next_batch(2)print('label: %d' % label[1])plt.imshow(image[1].reshape(28, 28), cmap='gray')plt.axis('off')plt.show()# 这里解释下为什么显示image[1]而不显示image[0]# image[0]图片显示像数字3,但标签label却是7
tensorflow-mnist
官方tensorflow-mnist代码,点这
官方tensorflow-mnist教程,点这(需翻墙)
'''tensorflow version: 1.0.0'''# 将这份文件保存为mnist_fcn.pyimport mathimport tensorflow as tf# 图片大小是28 * 28IMAGE_PIXELS = 28 * 28# 类别10个,分别对应数字0~9NUM_CLASSES = 10# 这个函数是创建一个隐层# 参数分别表示隐层的名字,流入数据,隐层单元个数及激活函数# 如果不使用激活函数则为输出层def layer(name, data_in, shape, activate): with tf.name_scope(name): weights = tf.Variable( tf.truncated_normal(shape, stddev=1 / math.sqrt(float(shape[0]))), name='weights' ) biases = tf.Variable( tf.zeros([shape[1]]), name='biases' ) if activate is None: # 注意网络的输出没有进行softmax data_out = tf.matmul(data_in, weights) + biases else: data_out = activate(tf.matmul(data_in, weights) + biases) return data_out# 定义网络结构# 参数分别表示输入图像,各隐层单元个数# 这个网络只有3个隐层,hidden_units = [x, y, z],分别表示各隐层神经元个数# 返回预测的值# 计算损失的函数tf.nn.sparse_softmax_cross_entropy_with_logits# 需要输入原始预测的值,所以这里不用进行softmax# 以下简单说下什么是softmax# 通俗来说对每个样本预测的10个值,每个值都表示属于0~9的一个概率# 但原始输出的10个值有正有负,大小也不在0~1之间,不能表示概率# 所以需要进行softmax,对预测值归一化,使其满足概率的定义# !!!!!!!!!!注意这里不需要归一化!!!!!!!!!!def model(images, hidden_units): data_out_h1 = layer( name='hidden1', data_in=images, shape=[IMAGE_PIXELS, hidden_units[0]], activate=tf.nn.relu ) data_out_h2 = layer( name='hidden2', data_in=data_out_h1, shape=[hidden_units[0], hidden_units[1]], activate=tf.nn.relu ) data_out_h3 = layer( name='hidden3', data_in=data_out_h2, shape=[hidden_units[1], hidden_units[2]], activate=tf.nn.relu ) logits = layer( name='softmax', data_in=data_out_h3, shape=[hidden_units[2], NUM_CLASSES], activate=None ) return logits# 定义损失函数# 参数分别是模型预测值,真实的值# 返回一个batch_size的平均损失def loss(logits, labels): cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name='cross_entropy' ) return tf.reduce_mean(cross_entropy, name='loss')# 定义优化器# 参数分别是损失值,学习率# 这里采用梯度下降算法# 返回训练网络的操作def optimizer(loss, learning_rate): opt = tf.train.GradientDescentOptimizer(learning_rate) train_op = opt.minimize(loss) return train_op# 定义评价函数# 参数分别是模型预测值,真实的值# 返回预测正确的个数# 比如说100张图片,预测正确80张,则返回80def evaluation(logits, labels): correct = tf.nn.in_top_k(logits, labels, 1) return tf.reduce_sum(tf.cast(correct, tf.int32))
'''tensorflow version: 1.0.0'''# 将这份文件保存为train_and_eval.py# 可直接运行,训练网络# 记得修改数据集的路径import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_fcn# 这个函数的作用是得到每次训练所需要的数据def fill_feed_dict(data_set, batch_size, images_pl, labels_pl): images, labels = data_set.next_batch(batch_size) feed_dict = { images_pl: images, labels_pl: labels } return feed_dict# 这个函数是每训练一定步数,就会在给定的数据集上测试性能# 比如在run_training函数中,每训练1000步# 会给出在训练集(training_set),验证集(validation set)和测试集(testing set)# 样本总数,正确分类的样本数和精度def do_eval(sess, data_set, batch_size, images_pl, labels_pl, eval_correct): true_counts = 0 steps = data_set.num_examples // batch_size num_examples = steps * batch_size for step in range(steps): feed_dict = fill_feed_dict( data_set, batch_size, images_pl, labels_pl ) true_counts += sess.run(eval_correct, feed_dict=feed_dict) precision = float(true_counts) / num_examples print('Num examples: %d, num correct: %d, precision @ 1 : %.4f' % (num_examples, true_counts, precision))def run_training(data_dir, batch_size, hidden_units, learning_rate, max_steps): # 读入数据,这时候data_sets包含3个数据集 # 训练集,验证集和测试集 data_sets = input_data.read_data_sets(data_dir) # 在默认的图中创建整个模型 with tf.Graph().as_default(): # 定义placeholder,用来向网络中传入数据 images_pl = tf.placeholder( tf.float32, shape=(batch_size, mnist_fcn.IMAGE_PIXELS) ) labels_pl = tf.placeholder( tf.int64, shape=(batch_size) ) # 图片输入网络,得到预测值 logits = mnist_fcn.model(images_pl, hidden_units) # 得到损失 loss = mnist_fcn.loss(logits, labels_pl) # 训练网络 train_op = mnist_fcn.optimizer(loss, learning_rate) # 得到一个batch_size中正确分类的图片张数 eval_correct = mnist_fcn.evaluation(logits, labels_pl) # 全局变量初始化,必备操作 init = tf.global_variables_initializer() # 默认图传入Session中 sess = tf.Session() # 初始化模型参数 sess.run(init) for step in range(max_steps): # 得到训练数据 feed_dict = fill_feed_dict( data_sets.train, batch_size, images_pl, labels_pl ) # 训练网络,得到损失函数值 _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict) # 每100步输出损失值 if step % 100 == 0: print('step: %d loss = %.3f' % (step, loss_value)) # 每1000步在3个数据集上进行评估操作 # 判断模型是否朝好的方向训练 if (step + 1) % 1000 == 0 or (step + 1) == max_steps: print('Eval on training set') do_eval(sess, data_sets.train, batch_size, images_pl, labels_pl, eval_correct) print('Eval on validation set') do_eval(sess, data_sets.validation, batch_size, images_pl, labels_pl, eval_correct) print('Eval on test set') do_eval(sess, data_sets.test, batch_size, images_pl, labels_pl, eval_correct)if __name__ == '__main__': # 参数设定 run_training(data_dir='........', batch_size=100, hidden_units=[200, 400, 100], learning_rate=0.01, max_steps=50000)
数据集及代码下载
点这里!!!!!!
0 0
- tensorflow-mnist手写数字识别
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- Tensorflow 实现 MNIST 手写数字识别
- 神经网络-tensorflow实现mnist手写数字识别
- tensorflow中mnist手写数字识别
- tensorflow中logistic识别mnist手写数字
- tensorflow中MLP识别mnist手写数字
- tensorflow构建RNN识别mnist手写数字
- TensorFlow学习---实现mnist手写数字识别
- TensorFlow实战—mnist手写数字识别
- tensorflow进行MNIST手写数字识别-CNN
- tensorflow进行MNIST手写数字识别-LSTM
- 训练Tensorflow识别手写数字 mnist
- TensorFlow笔记之一:MNIST手写数字识别
- Tensorflow入门 mnist手写数字识别
- Tensorflow , MNIST 识别你自己手写的数字
- Tensorflow MNIST 手写识别
- SQL 关于apply的两种形式cross apply 和 outer apply
- 关于Java的反射机制,你需要理解这些...
- React-router@4.0版本简易使用教程
- 递归实现单链表的查找
- Linux学习 第十一单元
- tensorflow-mnist手写数字识别
- HDU 4114 Disney's FastPass(floyd+状态压缩DP)
- 线段树(类似延迟标记) HDU
- CvArr、Mat、CvMat、IplImage、BYTE;QPixmap和QImage
- 用python脚本实现自动部署环境(一)
- 二叉树以及遍历算法
- 九大排序之——希尔排序
- ZOJ2975
- mybatis中插入数据自动返回自增长id的配置