TensorFlow的MNIST数据识别
来源:互联网 发布:javascript var 对象 编辑:程序博客网 时间:2024/06/04 18:24
1、读取数据
2、训练模型
3、完整样例
import sys

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Make the locally saved inference module importable.
sys.path.append(r'G:\MNIST最佳实践')
import mnist_inference

# NOTE(review): this module-level load is redundant — main() loads its own
# copy from "/datasets/MNIST_data" and passes it to train(); this global is
# never used. Kept for behavior compatibility; consider removing.
mnist = input_data.read_data_sets(r'G:\0tensorflow\MNIST_data', one_hot=True)

# Hyperparameters.
BATCH_SIZE = 100                # examples per training batch
LEARNING_RATE_BASE = 0.8        # initial learning rate
LEARNING_RATE_DECAY = 0.99      # exponential decay rate of the learning rate
REGULARIZATION_RATE = 0.0001    # L2 regularization strength
TRAINING_STEPS = 3000           # total number of SGD steps
MOVING_AVERAGE_DECAY = 0.99     # decay for the shadow-variable moving average


def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """Forward pass of a 2-layer ReLU network.

    If avg_class is None, use the raw variables; otherwise use the
    moving-averaged shadow values provided by avg_class.

    NOTE(review): this helper is never called below — train() uses
    mnist_inference.inference instead. Kept to preserve the module API.
    """
    if avg_class is None:
        # Forward pass with the current variable values.
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        # Forward pass with the moving-averaged variable values.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)


def train(mnist):
    """Build the graph and run TRAINING_STEPS of SGD on the given dataset."""
    # Name scope for the input placeholders.
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Name scope for the moving-average bookkeeping.
    with tf.name_scope("moving_average"):
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Name scope for the loss: cross-entropy plus the L2 terms collected
    # by mnist_inference under the 'losses' collection.
    with tf.name_scope("loss_function"):
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Name scope for the learning-rate schedule and the train op.
    with tf.name_scope("train_step"):
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE, global_step,
            mnist.train.num_examples / BATCH_SIZE,
            LEARNING_RATE_DECAY, staircase=True)
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=global_step)

        # One op that both applies gradients and updates the shadow variables.
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name='train')

    writer = tf.summary.FileWriter("/log/modified_mnist_train.log", tf.get_default_graph())

    # Training loop.
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            if i % 1000 == 0:
                # Configure full tracing for this step so TensorBoard can
                # show per-node time/memory usage.
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                # Proto that receives the runtime trace.
                run_metadata = tf.RunMetadata()
                _, loss_value, step = sess.run(
                    [train_op, loss, global_step], feed_dict={x: xs, y_: ys},
                    options=run_options, run_metadata=run_metadata)
                writer.add_run_metadata(run_metadata=run_metadata, tag=("tag%d" % i), global_step=i)
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
            else:
                _, loss_value, step = sess.run(
                    [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
    writer.close()


def main(argv=None):
    mnist = input_data.read_data_sets("/datasets/MNIST_data", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    main()
知识点:
1、添加路径
import sys
sys.path.append()
2.定义辅助函数来计算前向传播结果,使用ReLU作为激活函数
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """Forward pass of a 2-layer fully-connected network with ReLU.

    Args:
        input_tensor: input batch fed to the first layer.
        avg_class: an ExponentialMovingAverage instance, or None to use
            the raw variable values.
        weights1, biases1: parameters of the hidden layer.
        weights2, biases2: parameters of the output layer.

    Returns:
        The output-layer logits (no softmax applied).
    """
    # `is None` is the correct identity test (original used `== None`).
    if avg_class is None:
        # Forward pass without the moving-average class.
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        # Forward pass using the moving-averaged (shadow) variable values.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
3、获得前向传播算法 输出值
# 输入数据的命名空间。
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)  # L2正则化项
y = mnist_inference.inference(x, regularizer)  # 获得输出值
global_step = tf.Variable(0, trainable=False)
4.处理滑动平均的命名空间
# 处理滑动平均的命名空间。
with tf.name_scope("moving_average"):
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)  # 定义滑动平均的类
    variables_averages_op = variable_averages.apply(tf.trainable_variables())  # 列表的变量都会被更新
5、变量管理
tf.get_variable():创建或者获取变量;
tf.variable_scope():用于上下文管理器,使tf.get_variable()不出现变量重复调用的错误。
阅读全文
0 0
- TensorFlow的MNIST数据识别
- tensorflow下对MNIST数据集进行识别的程序代码
- tensorflow实现MNIST数据集识别
- tensorflow mnist数据集手写字识别
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- MNIST手写字识别的TensorFlow实现
- 基于tensorflow的MNIST数字识别
- 基于tensorflow的MNIST手写字识别
- tensorflow的hellow world:mnist手写识别
- Tensorflow , MNIST 识别你自己手写的数字
- Tensorflow学习:MNIST 识别
- Tensorflow MNIST 手写识别
- Tensorflow手写体识别mnist
- MNIST手写体识别--tensorflow
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 使用tensorflow对Mnist数据集进行字体识别
- Tensorflow入门-简单神经网络进行MNIST数据集识别
- ListBox实现上移,下移,左移,右移操作
- 2017/12/20
- SpringBoot中使用AOP 监控sql耗时
- 关于bin/storm nimbus >/dev/null 2>&1 &
- windows10上使用apache部署python flask webapp
- TensorFlow的MNIST数据识别
- 结构体学习1
- 无头单链表
- 面向对象程序设计上机练习十二(运算符重载)
- 基于java实现发短信的功能
- 171220—原码、反码、补码
- C#导入Xml文件到Sqlserver
- JNI 之Java和c/c++交互,提升Java变成效率
- JS内置对象