tensflow实战——MNIST(1)
来源:互联网 发布:异次元杀人矩阵 编辑:程序博客网 时间:2024/06/08 14:03
注:此博客基于tensorflow官网完整教程,具体数据下载处可去http://www.tensorfly.cn/tfdoc/tutorials/mnist_download.html
MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
60000行训练数据集 mnist.train
10000行测试数据集 mnist.test
mnist.train.images [60000,784] 维度1索引图片,维度2索引像素点
mnist.train.labels [60000,10] 标签数据”one-hot vectors”(一个one-hot向量除了一位数字为1以
外,其余为0)
1、下载安装数据集
提供一份自动下载和安装数据集 input_data.py
from tensorflow.examples.tutorials.mnist import input_datamnist1 = input_data.read_data_sets("MINST_data", one_hot=True)'''one-hot,Label是一个10维的向量,只有一个值为1,如果是数字0,那么对应的Label就是[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]。'''
2、定义
placeholder是占位符,第一个参数是数据类型dtype,第二个是tensor的shape。
Softmax Regression会对10类分别估算出一个概率,例如是0的概率为80%,数字1的概率是2%,那么它就会取最后那个概率最大的那个数
import tensorflow as tfsess = tf.InteractiveSession() # 使用这个命令会将这个session注册为默认的session,之后也会默认在这个session里跑。x = tf.placeholder(tf.float32, [None, 784]) '''接下来就是创建权重和偏差,这里因为就举个例子,所以就初始化为0就可以了,如果是其它复杂的例子,对初始化比较敏感的话,就不能这么简单的进行初始化了。'''W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))#Softmax Regression的实现y = tf.nn.softmax(tf.matmul(x, W) + b)
3、损失函数,优化算法
根据损失来找到最好的模型
y是预测的概率,y_是正确的标签
reduction_indices = [1]: 一种压缩方法具体见我的其他博文
reduce_mean:平均值
reduce_sum:求和
GradientDescentOptimizer(0.5):梯度下降,学习率为0.5
#交叉熵y_ = tf.placeholder(tf.float32, [None, 10])cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices = [1])) 1. #使用随机梯度下降进行优化,这里把学习率设为0.5,使用全局参数初始化器并直接执行它的run。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)init=tf.global_variables_initializer()sess.run(init)
4、训练数据
迭代执行训练操作
迭代1000次,每次100
for i in range(1000): batch = mnist1.train.next_batch(100) sess.run(train_step,feed_dict={x: batch[0], y: batch[1]})
5、准确率
argmax函数,给出某个tensor对象在某一堆上其数据最大值的所在的索引值。
(y,1):y 所索引的向量,1表示按行索引,0表示按列索引。
#计算分类是否正确,给出一组布尔值correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))#计算准确率,先转换为浮点数,取平均值accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(accuracy.eval({x: mnist1.test.images, y_: mnist1.test.labels}))
此预测模型准确率大概为91%左右,准确率不够高,原因是因为这个模型比较简单!
- tensflow实战——MNIST(1)
- Tensflow学习笔记(一)——TF生成并查看数据
- TensorFlow实战—mnist手写数字识别
- TensorFlow实战——CNN(LeNet5)——MNIST数字识别
- TensorFlow实战——DNN——MNIST数字识别
- MNIST实战
- tensorflow mnist实战笔记(一)了解官方mnist数据格式
- Tensorflow实战1:利用AlexNet训练MNIST
- tensorflow实战1:lstm实现mnist分类
- 罗斯基白话:TensorFlow + 实战系列(五)实战MNIST
- Tensorflow学习系列(1)——MNIST手写识别
- mnist——prototxt
- 深入MNIST(1)
- 深度学习系列——windows平台下跑微软caffe实战之运行mnist,cifar10
- caffe mnist数据实战
- caffe mnist实战训练(4)训练得出模型
- caffe mnist实战训练(5)用可视化图片解决问题
- TensorFlow实战-mnist手写数字识别(卷积神经网络)
- POJ1469[COURSES] 二分图最大匹配 匈牙利算法
- poj 3613(还是不懂,以后再看看)
- python里给出一个列表,怎么样从列表里取出最小两项的索引值
- 计蒜客 最大子阵列
- 十分钟学会pandas《10 Minutes to pandas》
- tensflow实战——MNIST(1)
- 写给自己的信
- Java-基础
- 基于特定领域国土GIS应用框架设计及应用
- 欢迎使用CSDN-markdown编辑器
- 数据结构——线性结构(3)——链栈的实现
- hibernate在用注解设置字段的默认值时遇到的问题
- HDU1556 color the ball (树状数组)
- 【Git】远程仓库的使用