Softmax Regression with MNIST

Source: Internet | Editor: 程序博客网 | Time: 2024/05/21 01:55

This article builds a Softmax Regression model, trains and tests it on the MNIST dataset, and uses it to introduce the most basic usage of TensorFlow (the 1.x API).

For background on the MNIST dataset and on Softmax regression, see: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
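For reference, here are the standard Softmax regression model and the cross-entropy loss that the code trains. x is an N×784 batch of flattened images, W is the 784×10 weight matrix, b is the 10-element bias, and y' holds the one-hot labels (the code variables x, w, b and y_). The logits z are what the code stores in y. The softmax p is applied inside the loss function. The last line gives the gradient that gradient descent follows.

```latex
z = xW + b
p_{nj} = \frac{e^{z_{nj}}}{\sum_{k=1}^{10} e^{z_{nk}}}
H = -\frac{1}{N} \sum_{n=1}^{N} \sum_{j=1}^{10} y'_{nj} \log p_{nj}
\frac{\partial H}{\partial z_{nj}} = \frac{p_{nj} - y'_{nj}}{N}
```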

  1. Importing the MNIST dataset

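    Call read_data_sets() and pass the directory where the MNIST data should be stored as its first argument. The function checks whether the data files are already in that directory and downloads them only if they are missing.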

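    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data

    # Load MNIST from ./MNIST_data (downloading it first if needed);
    # one_hot=True returns each label as a length-10 one-hot vector
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)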
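With one_hot=True, each label is returned as a length-10 vector instead of an integer. The vector has a 1 at the index of the digit and 0 elsewhere. The NumPy helper below sketches that encoding, assuming the labels are integers 0 to 9. The name to_one_hot is hypothetical and is not part of input_data.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Hypothetical helper: turn integer class labels into one-hot rows."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set a single 1.0 per row, in the column given by that row's label
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot

example = to_one_hot([3, 0, 9])
print(example)
```

Taking np.argmax along axis 1 recovers the original digits. The test code in step 4 relies on this.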
  2. Building the Softmax regression model

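    # Create the model
    # Placeholder for the input images: each row is one flattened 28x28 image (784 pixels)
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    # Weights and biases, initialized to zero
    w = tf.Variable(tf.zeros(shape=[784, 10]))
    b = tf.Variable(tf.zeros(shape=[10]))
    # Softmax regression: y holds the logits xW + b; softmax itself is applied inside the loss
    y = tf.matmul(x, w) + b
    # Define loss and optimizer: placeholder for the one-hot true labels
    y_ = tf.placeholder(tf.float32, [None, 10])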
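In the graph above, y holds only the logits xW + b. The softmax is applied later, inside tf.nn.softmax_cross_entropy_with_logits. The NumPy sketch below roughly shows what the model computes for one batch. It is not TensorFlow code, and random numbers stand in for real MNIST images.

```python
import numpy as np

def softmax(logits):
    # Subtracting the row-wise max avoids overflow and does not change the result
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.random((5, 784)).astype(np.float32)  # stand-in batch of 5 flattened images
w = np.zeros((784, 10), dtype=np.float32)    # zero initialization, as with tf.zeros
b = np.zeros(10, dtype=np.float32)

logits = x @ w + b       # corresponds to y = tf.matmul(x, w) + b
probs = softmax(logits)  # per-class probabilities; each row sums to 1
print(probs[0])
```

Because the weights start at zero, the untrained model gives every class probability 0.1.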
  3. Training and saving the model

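    import os

    # Start the session
    sess = tf.InteractiveSession()
    saver = tf.train.Saver()

    def train():
        # Cross-entropy between the labels and softmax(y), averaged over the batch
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
        # Gradient descent with learning rate 0.5; minimize() builds the back-propagation update
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
        tf.global_variables_initializer().run()
        # Train for 1000 steps on mini-batches of 100 examples
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    train()
    # Save the model. The relative path is resolved against the current working
    # directory, so create ./model first in case it does not exist yet
    os.makedirs('./model', exist_ok=True)
    saver.save(sess, save_path='./model/mnistmodel.ckpt')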
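In train(), tf.nn.softmax_cross_entropy_with_logits and GradientDescentOptimizer(0.5) do the loss and update work, and TensorFlow derives the gradients automatically. The NumPy sketch below roughly shows what each training step amounts to: plain gradient descent on the mean softmax cross-entropy, with the same learning rate of 0.5. It is an illustration, not TensorFlow's implementation. To keep it small, it assumes a toy dataset with 10 "pixels" per sample instead of 784. The name gradient_descent_step is hypothetical.

```python
import numpy as np

def log_softmax(logits):
    # Shift by the row max before exponentiating so large logits cannot overflow
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def gradient_descent_step(x, labels_one_hot, w, b, learning_rate=0.5):
    """Update w and b in place; return the loss measured before the update."""
    log_probs = log_softmax(x @ w + b)
    # Mean cross-entropy over the batch, like tf.reduce_mean in train()
    loss = -np.mean(np.sum(labels_one_hot * log_probs, axis=1))
    # Gradient of the mean loss w.r.t. the logits: (softmax - labels) / batch size
    grad_logits = (np.exp(log_probs) - labels_one_hot) / x.shape[0]
    w -= learning_rate * (x.T @ grad_logits)
    b -= learning_rate * grad_logits.sum(axis=0)
    return loss

# Toy stand-in for MNIST (an assumption): 100 samples whose 10 features mostly encode the label
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=100)
y_true = np.eye(10)[labels]
x = y_true + 0.1 * rng.random((100, 10))

w = np.zeros((10, 10))
b = np.zeros(10)
losses = [gradient_descent_step(x, y_true, w, b) for _ in range(50)]
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

The first recorded loss is log 10 ≈ 2.3026. That is the loss of the all-zero model, which predicts 0.1 for every class. The loss then falls with each update.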
  4. Loading the trained model

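    Before loading the model, you must first rebuild the complete network structure. Every variable has to be recreated with the same shape and in the same order as in the training script. tf.train.Saver matches checkpoint values to variables by name, and unnamed variables receive default names (Variable, Variable_1, ...) in the order they are created.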

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data

    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Rebuild the network structure (same variables, same order as in training)
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    w = tf.Variable(tf.zeros(shape=[784, 10]))
    b = tf.Variable(tf.zeros(shape=[10]))
    y = tf.matmul(x, w) + b
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Restore the trained variables; no initializer is needed because restore() assigns w and b
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'model/mnistmodel.ckpt')
        print('Weight:\n', sess.run(w))
        print('biases:\n', sess.run(b))
        # Test the trained model: fraction of test images whose predicted digit matches the label
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
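The test code compares tf.argmax(y, 1) with tf.argmax(y_, 1). Softmax preserves the ordering of the logits, so the argmax of y already gives the predicted digit without applying softmax. The NumPy sketch below performs the same accuracy calculation on a small hand-made batch. For illustration it uses 3 classes instead of 10, and the helper name accuracy is ours.

```python
import numpy as np

def accuracy(logits, labels_one_hot):
    # Fraction of rows whose highest-scoring class equals the labelled class
    predicted = np.argmax(logits, axis=1)
    actual = np.argmax(labels_one_hot, axis=1)
    return float(np.mean(predicted == actual))

logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 0.2, 0.9],
                   [1.5, 3.0, 0.0],
                   [0.0, 0.0, 4.0]])
labels = np.eye(3)[[0, 2, 0, 2]]
print(accuracy(logits, labels))  # prints 0.75 (rows 0, 1 and 3 are correct)
```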