Tensorflow第一个mnist数据集(理解)

来源:互联网 发布:手提旅行包 知乎 编辑:程序博客网 时间:2024/06/05 08:28

数据集

分为两部分

6000行的训练数据集 (mnist.train)

1000行的测试数据集(mnist.test)

在机器学习模型设计时必须要有一个单独的测试数据集不用来训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上 (泛化)

每个mnist数据单元有两部分组成 手写数字的图片(xs)和一个与之对应的标签(ys)训练数据集和测试数据集都包含xs ys 

例  训练数据集的图片是mnist.train.images 

训练数据集的标签是mnist.train.labels


 

每一张图片 包含28X28像素 用数组来表示这张照片我们把数组展开成一个向量 长度是28X28=784 每个照片用相同方式展开 mnist数据集的图片就是在784维向量空间里的点(拥有比较复杂的结构 此类数据的可视化是计算密集型的)

展开图片的数字数组会丢失图片的二维结构信息 优秀的计算机视觉方法会挖掘并利用这些结构信息 (后面学)  此次忽略这些结构 使用简单数学模型 softmax回归 不会利用这些结构信息。

在mnist训练数据集中,mnist.train.images是一个形状为 60000,784的张量, 第一维度数字用来索引图片 第二维度数字用来索引每张图片上的像素点,在此张量里的每个元素 都表示某张图片里的某个像素的强度值,0-1之间

相对应的mnist数据集的标签是介于0到9的数字 用来描述给定图片里表示的数字 使标签数据是one-hot vectors   一个One-hot向量除了某一位的数字是1以外其余各维度数字都是0 所以在此教程中数字n将表示成一个只有在第n维度数字为1的10维向量 比如 标签将表示成([1,0,0,0,0,0,0,0,0,0])因此 mnist.train.labels是一个[6000,10]的数字矩阵

Softmax 回归介绍(更多的关于Softmax函数的信息,可以参考Michael Nieslen的书里面的这个部分,其中有关于softmax的可交互式的可视化解释。)

      Mnist的每一张图片都表示一个数字, 0-9 我们的训练模型可能推测一张包含9的图片代表数字9的概率是80 但是判断8的概率是5 (上半部分都是小圆) 然后给予它代表其他数字的概率更小的值。

       这是一个使用softmax回归模型的经典案例。Softmax模型可以用来给不同的对象分配概率。 即使在训练更加精细的模型时,最后一步也需要用softmax来分配概率。

Softmax(二步)

第一步:对图片像素值进行加权求和(为了得到一张给定图片属于某个特定数字类的证据) 如果这个像素具有很强的证据说明这张图片不属于该类 那么相应的权值为负数相反 如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值是正数。

       下面图片显示了一个模型学习到的图片上每个像素对于特定数字类的权值。 红色表示负数权值  蓝色代表正数权值。

我们也需要 加入一个额外的偏执量(bias)因为输入往往会带来一些无关的干扰量,因此对于给定的输入图片X它代表的是数字id证据可以表示为

      

其中Wi 代表权重 

bi:数字i类的偏执量

J: 给定图片X的像素索引用于像素求和

然后用softmax函数可以把这些证据转换成概率Y:

这里的softmax 可以看出是一个激励(activation)函数或者链接(link)函数, 把我们定义的线性函数的输出转换成我们需要的格式,也就是关于10个数字类的概率分布。因此 给定一张照片 它对于每一个数字的吻合度可以被softmax函数转换成为一个概率值

Softmax函数可以定义为:

展开右边式子:

但是更多的时候把softmax 模型函数定义为前一种形式:

把输入值当成幂指数求值 再正则化这些结果值。

这个幂运算表示,更大的证据对应更大的假设模型里面的乘数权重值。反之,拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里的权值不可以是0值或者负值。Softmax 然后会 正则化这些权重值,使他们的总和等于1  ,以此构造一个有效的概率分布。

       对于softmax 回归模型下图解释 对于输入的XS加权求和,再分别加上一个偏置量 最后再输入到softmax函数中:

把它写成一个等式,可以得到:

也可以用向量表示这个计算过程:用矩阵乘法和向量相加。有助于提高计算效率:

更进一步可以写成更加紧凑的方式:

实现回归模型          

使用Numpy函数库(python库、实现高效的数值计算) 把类似矩阵乘法这样的复杂运算使用其他外部语言实现。 但是 从外部计算切换回python的每一个操作,仍然是一个很大的开销。如果你用GPU来进行外部计算,这样开销会更大。用分布式的计算方式,也会花费更多的资源用来传输数据。

Tensorflow也把复杂的计算放在python之外完成,为了避免多余的开销。它做了进一步完善。Tensorflow 不单独地运行单一的复杂计算,而是让我们可以先用图描述一系列可交互的计算操作,然后全部一起在python之外运行。(这样的运行方式 可以在不少的机器学习库中看到)

使用tensorflow之前首先导入;

Impor tensorflow as tf

我们通过操作符合变量来描述这些可交互的操作单元,创建:

X = tf.placeholder(‘float’,[None,784])

X不是一个特定的值而是一个占位符placeholder,tensorflow运行计算时输入这个值,我们希望能够输入任意数量的mnist图像,每一张图展开成784维的向量。我们用2维的浮点数张量来表示这些图,这个张量的形状是[none, 784](none 表示此张量的第一个维度可以是任何长度)

我们的模型也需要权重值和偏置量, 当然我们可以把他们当作是另外的输入(使用占位符)但tensorflow有一个更好的方法来表示他们:variable  一个variable代表一个可修改的张量 存在在tensorflow的用于描述交互性操作的图中。它们可以用于计算输入值,也可以在计算中被修改 对于机器学习应用,一般都会有模型参数用variable表示。

w = tf.variable(tf.zeros([784,10]))

B = tf.variable(tf.zeros([10]))

我们赋予tf.variable不同的初值来创建不同的variable :用全为零的张量来初始化W和B  (wb值可以初值可以随意设置)

W的维度是[784,10] 因为要用784维的图片向量乘以它以得到一个10维的证据值向量,每一位对应不同数字类。B的形状是[10],所以可以直接把它加到输出上面。

代码: y= tf.nn.softmax(tf.matmul(x,W)+b)

用tf.matmul(x,W)表示X乘W,对应之前等式里面的Wx  x是一个2维张量拥有多个输入。然后加B 把和输入到tf.nn.softmax函数里面。 

 

训练模型

为了训练模型 首先需要定义一个指标来评估这个模型是好的。 (在机器学习中通常定义指标来表示一个模型是坏的 这个指标成为 成本或损失 然后尽量最小化这个指标 但是这两种方式是相同的)

一个非常常见的成本函数是 交叉熵 (cross-entropy)(交叉熵产生于信息论里面的信息压缩编码技术,后来演变成为从博弈论到机器学习等其他领域里的重要技术手段)定义如下:

              

Y是预测的概率分布 y`是实际的分布(输入的one-hot vector)

比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。要详细理解!

为了计算交叉熵 首先需要添加一个新的占位符用于输入正确值:

Y_=tf.placeholder(‘float’,[none,10])

然后用

        计算交叉熵:

Cross_entropy =-tf.reduce_sum(y_*tf.log(y))

首先 用tf.log计算y的每一个元素的对数。

接下来 把y_的每一个元素和tf.log(y)的对应元素相乘。

最后  用tf.reduce_sum计算张量的所有元素的总和。(这里的交叉熵不仅是用来衡量单一的一对预测和真实值,而是所有100张图片的交叉熵的总和 。100个数据点的预测表现比单一数据点的表现能更好地描述我们的模型的性能)             

       知道模型的原理,用tensorflow来训练是非常容易的 因为tensorflow拥有一张描述你各个计算单元的图,它可以自动地使用反向传播算法来有效的确定你的变量是如何影响你想要最小化的那个成本值的 然后 tensorflow会用你选择的优化算法来不断的修改变量以降低成本。

Tran_step =tf.tran.GradientDescentOptimizer(0.01).minmize(cross_entropy)  

在此 要求tensorflow使用梯度下降的算法以0.01的学习速率最小化交叉熵。梯度下降算法是一个简单的学习过程 tensorflow只需要将每个变量一点点地往成本不断降低的方向移动。(当然tensorflow也提供了许多 优化算法 只要简单的调整一行代码就可以使用其他的算法)

Tensorflow实际上所做的是,在后台给描述你的计算的那张图里面增加一系列新的计算操作单元用于实现反响传播算法和梯度下降算法训练你的模型,微调你的变量,不断减少成本。

现在,我们已经设置好我们的模型。在运行计算之前,我们需要添加一个操作来初始化我们创建的变量:

Init =tf.initialize_all_variables()

现在我们可以在一个 Session里面启动我们的模型,并且初始化变量:

sess = tf.Session()

sees.run(init)

然后开始训练模型,我们让模型循环训练1000次

For  i in  range(1000):

       batch_xs , batch_ys =mnist.train.next_batch(100)

       sess.run(train_setp , feed_dict={x: batch_xs, y_: batch_ys})

该循环的每一步都会随机抓取训练数据中的100个批处理数理点,然后我们用这些数据点作为参数替换之前的占位符来运行train_step。

使用一小部分的随机数据来进行训练被称为随机训练。在这里更确切的说是随机梯度下降训练。在理想情况下,我们希望用我们所有的数据来进行每一步的训练,因为这能给我们更好的训练结果,但显然这需要很大的计算开销,又可以最大化地学习到数据集的总体特征。

评估模型

首先找出预测正确的标签。 tf.argmax是一个非常有用的函数,它给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引 位置就是类别标签,比如 tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而tf.argmax(y_,1)代表正确的标签,我们可以用tf.equal来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)

correct_prediction= tf.equal(tf.argmax(y, 1), tf.argmax(y_,1))

这行代码会给我们一组布尔值,为了确定正确预测项的比例, 可以把布尔值转换成浮点数,然后取平均值,例如 [true, flase, true, true]会变成[1,0,1,1]取平均值后得到 0.75

Accuracy =tf.reduce_mean(tf.cast(correct_prediction, ‘float’))

最后 计算模型在测试数据集上面的正确率

printsess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})  

 

结果值大约是91%

这个值并不太好 (很差) 。因为我们仅仅使用了一个非常简单的模型 不过 做一些小小的改进 ,就可以得到97的正确率。最后的模型甚至可以超过 99.7的准确率(想要更多信息可以看这个关于各种模型的性能对比列表)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         

阅读全文
0 0
原创粉丝点击