tensflow实战——MNIST(1)

来源:互联网 发布:异次元杀人矩阵 编辑:程序博客网 时间:2024/06/08 14:03

注:此博客基于tensorflow官网完整教程,具体数据下载处可去http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html

MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.

60000行训练数据集 mnist.train
10000行测试数据集 mnist.test

mnist.train.images [60000,784] 维度1索引图片,维度2索引像素点
mnist.train.labels [60000,10] 标签数据”one-hot vectors”(一个one-hot向量除了一位数字为1以
外,其余为0)

1、下载安装数据集
提供一份自动下载和安装数据集 input_data.py

from tensorflow.examples.tutorials.mnist import input_datamnist1 = input_data.read_data_sets("MINST_data", one_hot=True)'''one-hot,Label是一个10维的向量,只有一个值为1,如果是数字0,那么对应的Label就是[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]。'''

2、定义
placeholder是占位符,第一个参数是数据类型dtype,第二个是tensor的shape。
Softmax Regression会对10类分别估算出一个概率,例如是0的概率为80%,数字1的概率是2%,那么它就会取最后那个概率最大的那个数

y=softmax(Wx+b)

import tensorflow as tfsess = tf.InteractiveSession()   # 使用这个命令会将这个session注册为默认的session,之后也会默认在这个session里跑。x = tf.placeholder(tf.float32, [None, 784])    '''接下来就是创建权重和偏差,这里因为就举个例子,所以就初始化为0就可以了,如果是其它复杂的例子,对初始化比较敏感的话,就不能这么简单的进行初始化了。'''W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))#Softmax Regression的实现y = tf.nn.softmax(tf.matmul(x, W) + b)

3、损失函数,优化算法
根据损失来找到最好的模型

H(y)=ylog(y)

y是预测的概率,y_是正确的标签

reduction_indices = [1]: 一种压缩方法具体见我的其他博文
reduce_mean:平均值
reduce_sum:求和
GradientDescentOptimizer(0.5):梯度下降,学习率为0.5

#交叉熵y_ = tf.placeholder(tf.float32, [None, 10])cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices = [1])) 1. #使用随机梯度下降进行优化,这里把学习率设为0.5,使用全局参数初始化器并直接执行它的run。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)init=tf.global_variables_initializer()sess.run(init)

4、训练数据
迭代执行训练操作
迭代1000次,每次100

for i in range(1000):   batch = mnist1.train.next_batch(100)   sess.run(train_step,feed_dict={x: batch[0], y: batch[1]})

5、准确率

argmax函数,给出某个tensor对象在某一堆上其数据最大值的所在的索引值。
(y,1):y 所索引的向量,1表示按行索引,0表示按列索引。

#计算分类是否正确,给出一组布尔值correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))#计算准确率,先转换为浮点数,取平均值accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(accuracy.eval({x: mnist1.test.images, y_: mnist1.test.labels}))

此预测模型准确率大概为91%左右,准确率不够高,原因是因为这个模型比较简单!

原创粉丝点击