tensorflow:mnis入门代码注释

来源:互联网 发布:python中iteritems 编辑:程序博客网 时间:2024/05/18 01:56

把tensorflow教程中的mnist入门代码做了一个比较具体的注释,方便别人阅读也方便自己回顾。

# -*- coding: UTF-8 -*-  #导入tensorflow模块,用tf来表示import tensorflow as tf#导入tensorflow模块中的input_data文件from tensorflow.examples.tutorials.mnist import input_data#使用input_data文件中的read_data_sets函数读取mnist数据集#建立模型mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)x = tf.placeholder("float",[None,784])#创建占位符xw = tf.Variable(tf.zeros([784,10]))   #创建表示权重的tensor-wb = tf.Variable(tf.zeros([10]))       #创建表示偏置的tensor-by = tf.nn.softmax(tf.matmul(x,w)+b)   #根据w.x+b的结果创建softmax模型#计算交叉熵y_ = tf.placeholder("float",[None,10])#创建占位符y_cross_entropy = -tf.reduce_sum(y_*tf.log(y))#定义损失函数的表达式#训练模型train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#定义计算梯度下降法的学习速率lr以及要最优化的损失函数init = tf.initialize_all_variables()  #初始化所有变量的操作(op)sess = tf.Session()  #创建一个会话sess.run(init)for i in range(1000):    batch_xs,batch_ys = mnist.train.next_batch(50)  #每次循环调用此函数读取训练集中的50个数据和标签    sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})#sess.run(m)表示程序想要获取m的值,此处m是train_step,然后将之前的占位符x和y_分别用数据集中的数据和标签代替    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#分别取出y以及y_的最大值所在的索引,即1所在的索引,并判断他俩是否相等accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))#将判断的结果由bool转化为float,并求其平均值print sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})#sess.run同上,即到了这一步才会去执行之前预定的那些操作op,最后print输出结果zui
最后训练出来的模型在测试集上的准确率在91%左右,是非常不好的,原因也是这里只是用了一个softmax regression,后面将使用比这个复杂的模型来进行训练,准确率能够达到99%以上,后面博客再写。
1 0
原创粉丝点击