TensorFlow MNIST LeNet 模型持久化
来源:互联网 发布:弱电网络模块 编辑:程序博客网 时间:2024/05/22 08:24
前向传播过程mnist_inference.py
import tensorflow as tf# 定义神经网络相关的参数INPUT_NODE = 784OUTPUT_NODE = 10def inference(inputs, dropout_keep_prob): x_image = tf.reshape(inputs, [-1, 28, 28, 1]) # 第一层:卷积层 conv1_weights = tf.get_variable("conv1_weights", [5, 5, 1, 32], initializer=tf.truncated_normal_initializer(stddev=0.1)) # 过滤器大小为5*5, 当前层深度为1, 过滤器的深度为32 conv1 = tf.nn.conv2d(x_image, filter=conv1_weights, strides=[1, 1, 1, 1], padding='SAME') # 移动步长为1, 使用全0填充 conv1_biases = tf.get_variable("conv1_biases", [32], initializer=tf.constant_initializer(0.0)) relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases)) # 激活函数Relu去线性化 # 第二层:最大池化层 # 池化层过滤器的大小为2*2, 移动步长为2,使用全0填充 pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') #输出14*14*32 # 第三层:卷积层 conv2_weights = tf.get_variable("conv2_weights", [5, 5, 32, 64], initializer=tf.truncated_normal_initializer(stddev=0.1)) # 过滤器大小为5*5, 当前层深度为32, 过滤器的深度为64 conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME') # 移动步长为1, 使用全0填充 conv2_biases = tf.get_variable("conv2_biases", [64], initializer=tf.constant_initializer(0.0)) relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases)) # 第四层:最大池化层 # 池化层过滤器的大小为2*2, 移动步长为2,使用全0填充 pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') #输出7*7*64 # 第五层:全连接层 pool2_vector = tf.reshape(pool2, [-1, 7 * 7 * 64]) fc1_weights = tf.get_variable("fc1_weights", [7 * 7 * 64, 1024], initializer=tf.truncated_normal_initializer(stddev=0.1)) # 7*7*64=3136把前一层的输出变成特征向量 fc1_baises = tf.get_variable("fc1_baises", [1024], initializer=tf.constant_initializer(0.1)) fc1 = tf.nn.relu(tf.matmul(pool2_vector, fc1_weights) + fc1_baises) # 为了减少过拟合,加入Dropout层 fc1_dropout = tf.nn.dropout(fc1, dropout_keep_prob) # 第六层:全连接层 fc2_weights = tf.get_variable("fc2_weights", [1024, 10], initializer=tf.truncated_normal_initializer(stddev=0.1)) # 神经元节点数1024, 分类节点10 fc2_biases = tf.get_variable("fc2_biases", [10], initializer=tf.constant_initializer(0.1)) fc2 = tf.matmul(fc1_dropout, fc2_weights) + fc2_biases return fc2
训练mnist_train.py
import osimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_inference#BATCH_SIZE = 100#学习率LEARN_RATE = 0.001MODEL_SAVE_PATH = "model/"MODEL_NAME = "model.ckpt"EPOCH = 2def train(mnist): inputs = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE]) labels = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE]) dropout_keep_prob = tf.placeholder(tf.float32) logits = mnist_inference.inference(inputs, dropout_keep_prob) global_step = tf.Variable(0, trainable=False) cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels) #tf.nn.sparse_softmax_cross_entropy_with_logits cost = tf.reduce_mean(cross_entropy) train_op = tf.train.AdamOptimizer(LEARN_RATE).minimize(cost, global_step=global_step) saver = tf.train.Saver() with tf.Session() as sess: tf.global_variables_initializer().run() print(mnist.train.images.shape) for i in range(20000): batch_inputs, batch_labels = mnist.train.next_batch(BATCH_SIZE) _, cost_value, step = sess.run([train_op, cost, global_step], feed_dict={inputs: batch_inputs, labels: batch_labels, dropout_keep_prob:0.5}) if i % 1000 == 0: print("After %d training step(s), loss on training batch is %f." % (step, cost_value)) saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main(argv=None): mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) train(mnist)if __name__ == '__main__': tf.app.run()
评估mnis_eval.py
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist_inferenceimport mnist_traindef evaluate(mnist): inputs = tf.placeholder(tf.float32, [None, 784]) labels = tf.placeholder(tf.float32, [None, 10]) dropout_keep_prob = tf.placeholder(tf.float32) logits = mnist_inference.inference(inputs, dropout_keep_prob) print(logits) correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict={inputs: mnist.test.images, labels: mnist.test.labels, dropout_keep_prob:1.0}) print("After %s training step(s), validation accuracy = %f" % (global_step, accuracy_score)) else: print("No checkpoint file found") returndef main(argv=None): mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) evaluate(mnist)if __name__ == '__main__': tf.app.run()
阅读全文
0 0
- TensorFlow MNIST LeNet 模型持久化
- 【TensorFlow】MNIST(代码重构+模型持久化)
- 基于tensorflow和mnist的LeNet-5模型实现
- tensorflow 卷积神经网络 LeNet-5模型 MNIST手写体数字识别
- Tensorflow 模型持久化
- tensorflow--模型持久化
- 81、Tensorflow实现LeNet-5模型,多层卷积层,识别mnist数据集
- LeNet-5结构写Mnist识别(Tensorflow)
- Tensorflow中的模型持久化
- tensorflow 模型的持久化
- Tensorflow基础:模型持久化
- 5.4 TensorFlow模型持久化
- tensorflow实现LeNet-5模型
- 【TensorFlow】MNIST(使用LeNet5+滑动平均+正则化+指数衰减法+激活函数+模型持久化)
- Tensorflow模型持久化与恢复
- TensorFlow模型的保存和持久化
- Tensorflow模型持久化的代码实现
- TensorFlow MNIST CNN LeNet5模型
- Java并发编程:线程池的使用
- WPF简单教程:控件ViewBox
- Android RadioButton设置选中时文字和背景颜色同时改变
- Ajax认识
- Android Service和Thread的区别
- TensorFlow MNIST LeNet 模型持久化
- Java数组的遍历与求和
- util.God -2
- 阿里物联网套件-服务端SDK学习实践(基础篇-1准备)
- cookie跨域session共享
- 用两个栈实现队列
- (149)环境立方体贴图
- lintcode(139)最接近零的子数组和
- 一张图理清楚关系型/非关系型数据库与Elasticsearch同步