softmax regression的tensorflow实现

来源:互联网 发布:php广告系统 编辑:程序博客网 时间:2024/06/06 02:06

MNIST数据集的使用是机器学习领域的HelloWorld.
他由几万张28x28像素的图片组成,这些图片只包含灰度信息,我们要做的就是对这些图片进行分类,分为0-9共10类.
softmax regression 模型在对图片进行预测时会为每个类估算一个概率,最后取概率大的为输出结果。
处理多分类的问题通常使用该模型,CNN和RNN最后一层同样是Softmax Regression
这是代码部分

from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)print(mnist.train.images.shape, mnist.train.labels.shape)print(mnist.test.images.shape, mnist.test.labels.shape)print(mnist.validation.images.shape, mnist.validation.labels.shape)import tensorflow as tfsess = tf.InteractiveSession()x = tf.placeholder(tf.float32, [None, 784])W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x, W) + b)y_ = tf.placeholder(tf.float32, [None, 10])cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)tf.global_variables_initializer().run()for i in range(1000):   batch_xs, batch_ys = mnist.train.next_batch(100)   train_step.run({x: batch_xs, y_: batch_ys})correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

图中附有我写的解析:

这里写图片描述

这是实验结果:

这里写图片描述

建议用jupyter notebook打开,编写环境tensorflow1.1,python3.5。

点此查看Github源代码

原创粉丝点击