The MNIST Digit Recognition Problem
First, a quick look at the MNIST dataset itself:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset (downloaded automatically on first run)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Print the training data size
print('training data size:', mnist.train.num_examples)
# Print the validation data size
print('validating data size:', mnist.validation.num_examples)
# Print the testing data size
print('testing data size:', mnist.test.num_examples)
# Print an example training image (a flat 784-dimensional vector)
print('example training data\n', mnist.train.images[0])
# Print the corresponding one-hot training label
print('example training data label:', mnist.train.labels[0])
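Each printed image is just a flat vector of 784 floats, i.e. a 28 × 28 pixel grid. Continuing from the snippet above, here is a minimal sketch of how one might view a sample as an actual image, assuming matplotlib is installed (not part of the original code):

import matplotlib.pyplot as plt

# Reshape the flat 784-vector back into its 28x28 pixel grid
image = mnist.train.images[0].reshape(28, 28)
print('label (one-hot):', mnist.train.labels[0])
plt.imshow(image, cmap='gray')
plt.show()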
The next piece of code is admittedly long-winded, and I hope to simplify it later. It trains a neural network with TensorFlow, using a learning rate with exponential decay, L2 regularization to avoid overfitting, and a moving-average model to make the final model more robust.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784              # number of input nodes
OUTPUT_NODE = 10              # number of output nodes
LAYER1_NODE = 500             # number of hidden-layer nodes
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8      # base learning rate
LEARNING_RATE_DECAY = 0.99    # decay rate of the learning rate
REGULARIZATION_RATE = 0.0001  # weight of the regularization term
TRAINING_STEPS = 10000        # number of training iterations
MOVING_AVERAGE_DECAY = 0.99   # decay rate of the moving average

# Compute the forward-propagation result given the network's weights and biases
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    # Check whether an ExponentialMovingAverage object was passed in
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1))
                            + avg_class.average(biases1))
        return (tf.matmul(layer1, avg_class.average(weights2))
                + avg_class.average(biases2))

# Training procedure for the neural network model
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Define the parameters of the network structure
    weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward-propagation result without the moving-average model
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Variable that stores the current training step
    global_step = tf.Variable(0, trainable=False)

    # Create the ExponentialMovingAverage object, passing in the current step
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    # Op that updates the moving averages of all trainable variables
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Forward-propagation result under the moving-average model
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)

    # Cross-entropy loss
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # Define an L2 regularizer and apply it to weights1 and weights2
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization  # total loss

    # Exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    # Gradient-descent op; passing global_step makes it increment automatically
    train_step = tf.train.GradientDescentOptimizer(learning_rate) \
        .minimize(loss, global_step=global_step)

    # Combine the two ops
    train_op = tf.group(train_step, variables_averages_op)
    # Equivalent to tf.group():
    # with tf.control_dependencies([train_step, variables_averages_op]):
    #     train_op = tf.no_op(name='train')

    # Accuracy: the final predictions use the forward-propagation result
    # computed with the moving-averaged parameters
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Initialize the session and start iterative training
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Data to feed for the validation set
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        # Data to feed for the test set
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print('After %d training steps, validation accuracy'
                      ' using average model is %f' % (i, validate_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print('After %d training steps, test accuracy'
              ' using average model is %f' % (TRAINING_STEPS, test_acc))

# Main function
def main(argv=None):
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    train(mnist)

# Runs only when this file is executed directly from the shell,
# not when it is imported as a Python module.
if __name__ == '__main__':
    tf.app.run()  # calls main() and passes in the required argument list
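Two of the pieces above are easy to gloss over. tf.train.exponential_decay computes learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ^ (global_step / decay_steps), a continuous decay by default (no staircase). tf.train.ExponentialMovingAverage maintains a shadow copy of each variable, updated as shadow = decay * shadow + (1 - decay) * variable, where the effective decay is min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)) so the averages track the variables more closely early in training. A minimal plain-Python sketch of both update rules, for illustration only (the decay_steps value is an assumption, roughly mnist.train.num_examples / BATCH_SIZE):

LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99
decay_steps = 550  # assumed: ~ mnist.train.num_examples / BATCH_SIZE

def decayed_learning_rate(global_step):
    # Continuous exponential decay, as tf.train.exponential_decay computes it
    return LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps)

def ema_update(shadow, variable, global_step):
    # ExponentialMovingAverage uses the smaller of the configured decay and
    # (1 + step) / (10 + step), so early steps follow the variable more closely
    decay = min(MOVING_AVERAGE_DECAY, (1.0 + global_step) / (10.0 + global_step))
    return decay * shadow + (1.0 - decay) * variable

print(decayed_learning_rate(0))     # 0.8
print(decayed_learning_rate(5500))  # 0.8 * 0.99**10, about 0.7235
print(ema_update(0.0, 1.0, 0))      # decay = 0.1, so the shadow jumps to 0.9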