Softmax Regression with MNIST
来源:互联网 发布:淘宝上好的牛排店 编辑:程序博客网 时间:2024/05/21 01:55
本文通过搭建Softmax Regression,并用MNIST数据集进行训练以及测试,介绍tensorflow的最基础使用方式。
MNIST数据集介绍以及Softmax回归介绍参考:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
MNIST数据集导入
通过调用read_data_sets(),第一个参数填MNIST数据集存储路径,函数会自动判断当前路径下是否下载好数据,是否需要重新下载。
import tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data",one_hot=True)
Softmax回归模型搭建
# Create the model#通过操作符号变量创建一个可交互的操作单元x = tf.placeholder(dtype=tf.float32, shape=[None, 784])#权重值和偏置量的创建w = tf.Variable(tf.zeros(shape=[784, 10]))b = tf.Variable(tf.zeros(shape=[10]))#Softmax模型创建y = tf.matmul(x, w) + b;# Define loss and optimizery_ = tf.placeholder(tf.float32, [None, 10])
训练模型存储
#模型启动sess = tf.InteractiveSession()saver=tf.train.Saver()def train(): #交叉熵计算 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) #执行反向传播 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) tf.global_variables_initializer().run() # Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})train()#模型存储,默认存储路径为工程同目录下文件夹saver.save(sess,save_path='./model/mnistmodel.ckpt')
训练模型载入
载入模型时,必须先完整还原网络结构的所有参数
import tensorflow as tfimport numpy as npimport tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data",one_hot=True)myGraph = tf.Graph()#还原网络结构x = tf.placeholder(dtype=tf.float32, shape=[None, 784])w = tf.Variable(tf.zeros(shape=[784, 10]))b = tf.Variable(tf.zeros(shape=[10]))y = tf.matmul(x, w) + b;# Define loss and optimizery_ = tf.placeholder(tf.float32, [None, 10])# Test trained model#提取变量saver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess,'model/mnistmodel.ckpt') print('Weight:\n',sess.run(w)) print('biases:\n',sess.run(b)) #test correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))_: mnist.test.labels}))
阅读全文
0 0
- Softmax Regression with MNIST
- TensorFlow入门-MNIST & softmax regression
- MNIST和softmax回归(softmax regression)
- Tensorflow实现Softmax Regression 手写识别MNIST
- Tensorflow学习笔记(3)-mnist(softmax regression)
- TensorFlow的softmax regression做mnist例子
- TensorFlow学习笔记(2)----Softmax Regression分类MNIST
- MNIST(一):最简单的softmax regression 模型训练
- tensorflow实现softmax回归(softmax regression)——简单的MNIST识别(第一课)
- SoftMax regression
- Softmax Regression
- softmax regression
- Softmax Regression
- Softmax Regression
- Softmax regression
- Softmax Regression
- softmax回归(Softmax Regression)
- Tensorflow的Helloword:使用简单Softmax Regression模型来识别Mnist手写数字
- JavaScript日期操作
- 扩大按钮UIButton的点击范围
- PCManFTP2.0漏洞分析
- Android注解使用之Annotation实现原理
- getTranslationX与getLeft()的联系
- Softmax Regression with MNIST
- matlab图像锐化
- Spring中ClassPathXmlApplicationContext类的简单使用
- cocos2d-x 3.x内存管理
- ios集成ijkplayer框架
- 按bit写入的性能小测试
- Jna
- leetcode 661. Image Smoother
- cscope 使用时打开新的窗口