MNISTone----用的softmax

来源:互联网 发布:未闻花名但知花香出处 编辑:程序博客网 时间:2024/05/29 03:23

# -*- coding: UTF-8 -*- '''Created on 2017年12月8日'''#以下两句用于下载数据import tensorflow.examples.tutorials.mnist.input_data as input_dataimport tensorflow as tf  mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)     #下载并加载mnist数据#输入输出占位符x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量,占位符,每一个sample都是784维,none表示可以有任意个sample y_ = tf.placeholder("float", [None,10])  #占位符,每一个sample都是10维,因为是one_hot#参数W = tf.Variable(tf.zeros([784,10]))  #权重,初始化值为全零,变量b = tf.Variable(tf.zeros([10]))  #偏置,初始化值为全零,变量#进行模型建立及计算,y是预测,y_ 是实际  y = tf.nn.softmax(tf.matmul(x,W) + b)    #计算交叉熵  cross_entropy = -tf.reduce_sum(y_*tf.log(y+1e-10))  tf.scalar_summary('cross_entropy',cross_entropy) #接下来使用BP算法来进行微调,以0.01的学习速率,使用的是简单的梯度下降算法----记住,这是一个优化算子train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  #上面设置好了模型,添加初始化创建变量的操作  init = tf.initialize_all_variables()  #启动创建的模型,并初始化变量  sess = tf.Session()  sess.run(init)    #init也是操作merged = tf.merge_all_summaries() #collect the tf.xxxxx_summary  writer = tf.train.SummaryWriter('/home/tensorBoardLog/MNISTone',sess.graph)   #开始训练模型,循环训练1000次  for i in range(1000):      #随机抓取训练数据中的100个批处理数据点      batch_xs, batch_ys = mnist.train.next_batch(100)      #mnist.train #这里边有疑问:next_batch is a method of the DataSet class    #https://stackoverflow.com/questions/40368697/where-does-next-batch-in-the-tensorflow-tutorial-batch-xs-batch-ys-mnist-trai    #可以在github上看到    summary,loss, _= sess.run([merged, cross_entropy, train_step], feed_dict={x:batch_xs,y_:batch_ys})  #train_step是一个操作,step表示每一步    #注意是操作;给模型必要的输入,以及必要的操作指示    writer.add_summary(summary, i)    print('range: %04d, loss = %-9f' % (i+1, loss))''''' 进行模型评估 '''  #判断预测标签和实际标签是否匹配  correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))   #1表示在1轴上,0轴表示的是样本indexaccuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  #tf.cast是类型转换函数#计算所学习到的模型在测试数据集上面的正确率  print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )  #mnist.test,注意了,x,y只是占位符,train test都可以用#accuracy,注意了,这个时候w,b不会再变了,所以x进去自然会有一个y出来;feed_dict表示输入字典

ss