TensorFlow classification problems


MNIST classification example

Take the tensor of a random image (an array of length 784). You can see that all pixel values have been normalized: every entry is at most 1.
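For instance, a vector like the one below could be printed with a snippet along these lines (a minimal sketch, assuming the dataset has been downloaded into MNIST_data/ as in the code further down):

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
img = mnist.train.images[0]     # one flattened 28x28 image, shape (784,)
print img.shape, img.max()      # max value is <= 1.0 because pixels are scaled to [0, 1]
print img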

[ 0.          0.          0.          0.          0.          0.          0.
  ...
  0.38039219  0.37647063  0.3019608   0.46274513  0.2392157   ...
  0.35294119  0.5411765   0.92156869  0.92156869  0.92156869  0.92156869
  0.98431379  0.98431379  0.97254908  0.99607849  0.96078438  ...
  0.          0.          0.          0.          0.          0.        ]

(Output truncated; the full vector has 784 entries, every one of them in [0, 1].)

Code

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST; one_hot=True turns each label into a length-10 one-hot vector.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])   # flattened 28x28 images
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)        # predicted class probabilities
y_ = tf.placeholder("float", [None, 10])      # true labels (one-hot)

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(20)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Accuracy on the test set.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
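A side note in case this snippet is run on a newer TensorFlow 1.x release: tf.initialize_all_variables() has been deprecated in favor of tf.global_variables_initializer(), and the print statements above assume Python 2. A minimal adjustment, keeping the same graph, might look like:

# Assumes the graph built above (x, y, y_, accuracy) is already defined.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))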

Example 2: train a classifier where 0-1 is class 1, 1-2 is class 2, and 2-3 is class 3.


#coding=utf-8
import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None, 1])     # one scalar feature per sample
W = tf.Variable(tf.zeros([1, 3]))
b = tf.Variable(tf.zeros([3]))
y = tf.nn.softmax(tf.matmul(x, W) + b)        # probabilities for the 3 classes
y_ = tf.placeholder("float", [None, 3])

# 1e-10 avoids log(0) when a predicted probability is exactly zero.
cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-10))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Classify 0-1 as class 1, 1-2 as class 2, 2-3 as class 3.
x_data_1 = np.linspace(0, 1, 100)
x_data_2 = np.linspace(1, 2, 100)
x_data_3 = np.linspace(2, 3, 100)
x_data = np.concatenate((x_data_1, x_data_2, x_data_3)).reshape(-1, 1)
y_data = np.array([1, 0, 0] * 100 + [0, 1, 0] * 100 + [0, 0, 1] * 100).reshape(-1, 3)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_data, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

for i in range(200):
  sess.run(train_step, feed_dict={x: x_data, y_: y_data})
  if i % 50 == 0:
    print "Accuracy: " + str(sess.run(accuracy, feed_dict={x: x_data, y_: y_data}))
    print "Loss: " + str(sess.run(cross_entropy, feed_dict={x: x_data, y_: y_data}))

print "Result: " + str(sess.run(tf.argmax(y, 1), feed_dict={x: [[0.7], [1.3], [2.2], [1.7], [0.11], [2.333]]}))

Result

Accuracy: 0.333333   Loss: 354.704
Accuracy: 0.763333   Loss: 130.619
Accuracy: 0.916667   Loss: 67.9843
Accuracy: 0.986667   Loss: 52.7412
Result: [0 1 2 1 0 2]   # by index: 0 means class 1, 1 means class 2, 2 means class 3

Because the inputs in the training data do not differ by much (here they span 0-1, 1-2, and 2-3), the training results are quite satisfactory.


Example 3: train a classifier where 0-100 is class 1, 100-200 is class 2, and 200-300 is class 3.

#coding=utf-8
import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None, 1])
W = tf.Variable(tf.zeros([1, 3]))
b = tf.Variable(tf.zeros([3]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float", [None, 3])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-10))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Classify 0-100 as class 1, 100-200 as class 2, 200-300 as class 3.
x_data_1 = np.linspace(0, 100, 100)
x_data_2 = np.linspace(100, 200, 100)
x_data_3 = np.linspace(200, 300, 100)
x_data = np.concatenate((x_data_1, x_data_2, x_data_3)).reshape(-1, 1)
y_data = np.array([1, 0, 0] * 100 + [0, 1, 0] * 100 + [0, 0, 1] * 100).reshape(-1, 3)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_data, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

for i in range(1000):
  sess.run(train_step, feed_dict={x: x_data, y_: y_data})
  if i % 50 == 0:
    print "Accuracy: " + str(sess.run(accuracy, feed_dict={x: x_data, y_: y_data}))
    print "Loss: " + str(sess.run(cross_entropy, feed_dict={x: x_data, y_: y_data}))

print "Result: " + str(sess.run(tf.argmax(y, 1), feed_dict={x: [[0.7], [1.3], [288], [1.7], [2222.2], [133.333]]}))

Result

Accuracy: 0.333333   Loss: 4583.25
Accuracy: 0.336667   Loss: 4582.96
Accuracy: 0.336667   Loss: 4582.77
Accuracy: 0.336667   Loss: 4582.63
Accuracy: 0.336667   Loss: 4582.54
Accuracy: 0.336667   Loss: 4582.47
Accuracy: 0.336667   Loss: 4582.42
Accuracy: 0.336667   Loss: 4582.38
Accuracy: 0.336667   Loss: 4582.35
Accuracy: 0.336667   Loss: 4582.33
Accuracy: 0.336667   Loss: 4582.31
Accuracy: 0.336667   Loss: 4582.3
Accuracy: 0.336667   Loss: 4582.28
Accuracy: 0.336667   Loss: 4582.27
Accuracy: 0.336667   Loss: 4582.26
Accuracy: 0.336667   Loss: 4582.25
Accuracy: 0.336667   Loss: 4582.25
Accuracy: 0.336667   Loss: 4582.24
Accuracy: 0.336667   Loss: 4582.23
Accuracy: 0.336667   Loss: 4582.23
Result: [2 2 2 2 2 2]   # most of the samples are misclassified

The accuracy never improves and the loss decreases very slowly. The loss function is defined as cross_entropy = -tf.reduce_sum(y_*tf.log(y+1e-10)), where y_ is the true label and y is the prediction. When the predicted probability y for the true class is close to 0, log(y + 1e-10) ≈ log(1e-10) ≈ -23, so that sample contributes a very large loss term, and the summed loss is correspondingly huge. We should rescale the raw inputs x into a reasonable range, for example by dividing x by some constant; in this example we can divide x by 100 when computing the activation, i.e. change the line to y = tf.nn.softmax(tf.matmul(x/100, W) + b). The reason: classification uses softmax as the activation function, defined as

    softmax(x)_j = exp(x_j) / Σ_k exp(x_k)

As the formula shows, the denominator is the sum of exp over all the inputs and the numerator is exp of a single one, so the larger an input is, the larger its share. When the inputs differ greatly in magnitude, exp of the large inputs dominates: softmax tends to 1 for them and to 0 for the small ones, which is why the output collapses to something like y = [0, 0, 1].
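A quick numerical sketch (plain NumPy; the names here are chosen only for illustration) makes the saturation concrete: with logits on the scale of the raw inputs, softmax is effectively one-hot, and the log term guarded by 1e-10 becomes a large constant.

import numpy as np

def softmax(z):
    z = z - np.max(z)              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

print softmax(np.array([0.5, 1.5, 2.5]))    # moderate logits -> soft probabilities
print softmax(np.array([50., 150., 250.]))  # large logits -> essentially [0, 0, 1]
print -np.log(1e-10)                        # ~23: the per-sample loss when y is ~0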

Improved code:

#coding=utf-8
import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None, 1])
W = tf.Variable(tf.zeros([1, 3]))
b = tf.Variable(tf.zeros([3]))
# Divide x by 100 so the logits stay in a small range and softmax does not saturate.
y = tf.nn.softmax(tf.matmul(x / 100, W) + b)
y_ = tf.placeholder("float", [None, 3])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-10))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Classify 0-100 as class 1, 100-200 as class 2, 200-300 as class 3.
x_data_1 = np.linspace(0, 100, 100)
x_data_2 = np.linspace(100, 200, 100)
x_data_3 = np.linspace(200, 300, 100)
x_data = np.concatenate((x_data_1, x_data_2, x_data_3)).reshape(-1, 1)
y_data = np.array([1, 0, 0] * 100 + [0, 1, 0] * 100 + [0, 0, 1] * 100).reshape(-1, 3)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_data, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

for i in range(500):
  sess.run(train_step, feed_dict={x: x_data, y_: y_data})
  if i % 50 == 0:
    print "Accuracy: " + str(sess.run(accuracy, feed_dict={x: x_data, y_: y_data}))
    print "Loss: " + str(sess.run(cross_entropy, feed_dict={x: x_data, y_: y_data}))

print "Result: " + str(sess.run(tf.argmax(y, 1), feed_dict={x: [[0.7], [1.3], [288], [1.7], [2222.2], [133.333]]}))

Results:

Accuracy: 0.333333   Loss: 354.704
Accuracy: 0.763333   Loss: 130.619
Accuracy: 0.916667   Loss: 67.9843
Accuracy: 0.986667   Loss: 52.7412
Accuracy: 0.986667   Loss: 48.5135
Accuracy: 0.986667   Loss: 45.4217
Accuracy: 0.986667   Loss: 43.0049
Accuracy: 0.99       Loss: 41.0391
Accuracy: 0.99       Loss: 39.3952
Accuracy: 0.993333   Loss: 37.9913
Result: [0 0 2 0 2 1]

The accuracy now improves quickly, the loss keeps decreasing, and the predictions are entirely correct.

Conclusion:

For classification problems that use softmax as the activation function together with the log loss -tf.reduce_sum(y_*tf.log(y+1e-10)): when the input values x differ greatly in magnitude, rescale or normalize x first, otherwise gradient descent converges extremely slowly.
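As a rough sketch of that conclusion (the variable names below are illustrative, not part of the example code above): instead of dividing by a hard-coded constant inside the graph, the raw inputs can be min-max normalized before being fed to the model, as long as new points are scaled with the same statistics.

import numpy as np

# Hypothetical raw inputs spanning a wide range, e.g. 0-300 as in example 3.
x_raw = np.concatenate((np.linspace(0, 100, 100),
                        np.linspace(100, 200, 100),
                        np.linspace(200, 300, 100))).reshape(-1, 1)

# Min-max normalization to [0, 1]; keep x_min / x_max to transform new inputs later.
x_min, x_max = x_raw.min(), x_raw.max()
x_norm = (x_raw - x_min) / (x_max - x_min)

# At prediction time, new points must be scaled with the training statistics.
new_points = np.array([[0.7], [1.3], [288.0]])
new_norm = (new_points - x_min) / (x_max - x_min)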
