【深度学习】Tensorflow学习记录(一) softmax regression mnist训练

来源:互联网 发布:商业地产it管理系统 编辑:程序博客网 时间:2024/04/30 10:49

之前学了2个月的caffe,最近打算开始学TensorFlow,这里记录相关的学习、实践测试笔记。
TensorFlow安装教程 (局部安装conda方法):
http://blog.csdn.net/bitcarmanlee/article/details/52749488

入门笔记

TensorFlow是由Google开发第二代(基于DistBelief)分布式的机器学习算法实现框架和部署系统,前端支持Python,C++,Go,Java等多种语言,后端使用C++,CUDA等写成,可在众多异构系统上方便地移植,CPU,GPU集群,iOS,Android等。
Github网址: https://github.com/tensorflow/tensorflow
模型仓库网址:https://github.com/tensorflow/models
1、模型简介
TensorFlow采用数据流式图(DataFlow-like)来规划计算流程,其中的计算可表示为计算图(computation graph),其实就是有向图,每个运算操作作为结点,用边连接。在计算图中流动的数据称为张量(tensor),得名TensorFlow。
2、常用组件
(1)google在2016年2月开源了TensorFlow Serving组件,可将训练好的模型导出,并部署成可以对外提供预测服务的RESTful接口。有了这个组件,TensorFlow就可以实现应用机器学习的全流程:训练,调试参数,打包膜性,部署服务。
(2)tensorBoard是TensorFlow的一组Web应用,用来监控TensorFlow的运行过程,可视化computation graph。可以用来持续监控运行时的关键指标,loss,learning rate,验证集的accuracy等。
3、TensorFlow实现Softmax regression识别手写数字

首先进入python/ipython下载运行这段代码,
然后执行
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)

两段代码导入数据。
这里写图片描述
查看mnist数据集情况:
这里写图片描述
下面运用softmax regression来实现手写数字识别:
首先载入tf并创建interactiveSession
这里写图片描述

  #创建placeholder保存输入数据 x=tf.placeholder(tf.float32,[None,784])  #weights,biases创建Variable对象 w=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) #实现softmax regression算法 y=tf.nn.softmax(tf.matmul(x+w)+b)

为了训练模型,需要定义loss function,通常使用cross-entropy来作为loss func,cross-entropy定义来自信息论

Hy(y)=iyilog(yi)

y’是真实概率分布(label的one-hot编码),y是预测概率分布,通常用上述公式判断模型对真实概率奋不顾级的准确度。
在tf中定义cross-entropy:
这里写图片描述

#使用tf全局参数初始化器,并直接执行:tf.global_variables_initializer().run()#开始迭代执行训练train_step,每次随机从样本中去100条样本构成mini-batch

for i in range(1000):
…… batch_xs, batch_ys = mnist.train.next_batch(100)
…… sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

评估模型:
首先让我们找出那些预测正确的标签。tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75.

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

整个过程参见下图,准确率92%左右:
这里写图片描述
补充图:
这里写图片描述

0 0
原创粉丝点击