【TensorFlow】MNIST(代码重构+模型持久化)

来源:互联网 发布:js 屏蔽运营商广告 编辑:程序博客网 时间:2024/06/07 09:07

项目已上传至 GitHub —— best

下载MNIST数据集


官方下载地址(可能需要梯子)

http://yann.lecun.com/exdb/mnist/

这里提供了百度网盘的下载地址,需要的自取

链接: https://pan.baidu.com/s/1geOcXxT 密码: mws8

下载之后将其放在 mnist/data/ 文件夹下,目录结构如下

mnist/
    data/
        train-images-idx3-ubyte.gz
        train-labels-idx1-ubyte.gz
        t10k-images-idx3-ubyte.gz
        t10k-labels-idx1-ubyte.gz

代码重构


为了使代码有更好的可读性和扩展性,需要将之按功能分为不同的模块,并将可重用的代码抽象成库函数

所以可以把以前臃肿的 MNIST 代码分成三个模块

  • inference
  • train
  • eval

具体的文件夹目录如下

mnist/
    data/
        ......
    best/
        inference.py
        train.py
        eval.py

完整代码


代码实现自《TensorFlow:实战Google深度学习框架》

首先是 inference.py ,这个库函数负责模型训练及测试的前向传播过程

import tensorflow as tf# 定义神经网络相关参数INPUT_NODE = 784OUTPUT_NODE = 10LAYER1_NODE = 500# 创建权重变量,并加入正则化损失集合def get_weight_variable(shape, regularizer):    weights = tf.get_variable(        'weights',        shape,        initializer=tf.truncated_normal_initializer(stddev=0.1))    if regularizer != None:        tf.add_to_collection('losses', regularizer(weights))    return weights# 前向传播def inference(input_tensor, regularizer):    # 声明隐藏层的变量并进行前向传播    with tf.variable_scope('layer1'):        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)        biases = tf.get_variable(            'biases', [LAYER1_NODE], initializer=tf.constant_initializer(0.0))        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)    # 声明输出层的变量并进行前向传播    with tf.variable_scope('layer2'):        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)        biases = tf.get_variable(            'biases', [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))        layer2 = tf.matmul(layer1, weights) + biases    return layer2

然后是 train.py ,训练模型的模块

# train.py -- trains the MNIST model and periodically checkpoints it.
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import inference

# Optimizer hyper-parameters.
LEARNING_RATE_BASE = 0.8      # initial learning rate
LEARNING_RATE_DECAY = 0.99    # exponential decay rate of the learning rate
REGULARIZATION_RATE = 0.0001  # weight of the L2 term in the total loss
MOVING_AVERAGE_DECAY = 0.99   # decay for the shadow (moving-average) variables

# Training parameters.
BATCH_SIZE = 100        # images per training batch
TRAINING_STEPS = 30000  # total number of training iterations

# Checkpoint location.
MODEL_SAVE_PATH = 'model/'
MODEL_NAME = 'mnist.ckpt'


def train(mnist):
    """Build the training graph and run TRAINING_STEPS iterations.

    Args:
        mnist: the dataset object returned by input_data.read_data_sets,
            providing mnist.train.next_batch and mnist.train.num_examples.

    Side effects:
        Saves a checkpoint to MODEL_SAVE_PATH every 1000 steps and prints
        the current loss.
    """
    # Placeholders for a batch of images and their one-hot labels.
    x = tf.placeholder(
        tf.float32, [None, inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(
        tf.float32, [None, inference.OUTPUT_NODE], name='y-input')

    # L2 regularization; inference() stashes the per-layer terms in the
    # 'losses' collection.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = inference.inference(x, regularizer)

    # Global step counter; not trainable so the optimizer leaves it alone.
    global_step = tf.Variable(0, trainable=False)

    # Maintain shadow moving averages of all trainable variables; eval.py
    # restores these averages instead of the raw weights.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    # Learning rate decays once per epoch (num_examples / BATCH_SIZE steps).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY)

    # Cross-entropy over the batch; sparse variant takes class indices,
    # so the one-hot labels are collapsed with argmax.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = data loss + accumulated L2 regularization terms.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Group the gradient step and the moving-average update into one op.
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    # Fix: create the checkpoint directory up front -- saver.save raises
    # an error if MODEL_SAVE_PATH does not exist.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            # Fetch the next batch and run one combined train/EMA step.
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={
                    x: xs,
                    y_: ys
                })
            # Report progress and checkpoint every 1000 steps.
            if i % 1000 == 0:
                print('After %d training steps, loss is %g.' % (step,
                                                                loss_value))
                saver.save(
                    sess,
                    os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                    global_step=global_step)


def main(argv=None):
    """Load the MNIST data and start training."""
    mnist = input_data.read_data_sets('../data/', one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()

最后是 eval.py ,可以在训练模型的同时,每隔一段时间利用最新保存的模型进行测试

# eval.py -- periodically loads the latest checkpoint and reports
# validation accuracy while training runs in parallel.
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import inference
import train

# Reload the newest checkpoint and re-evaluate every 10 seconds.
EVAL_INTERVAL_SECS = 10


def evaluate(mnist):
    """Loop forever, restoring the newest checkpoint and printing accuracy.

    Args:
        mnist: dataset object providing mnist.validation.images/labels.

    Returns when no checkpoint file can be found.
    """
    # Build the eval graph in its own Graph so it never clashes with a
    # training graph living in the default graph of the same process.
    with tf.Graph().as_default() as g:
        x = tf.placeholder(
            tf.float32, [None, inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(
            tf.float32, [None, inference.OUTPUT_NODE], name='y-input')
        # No regularizer at eval time -- regularization only affects loss.
        y = inference.inference(x, None)

        # The whole validation set is fed in one batch.
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }

        # Accuracy = fraction of samples whose argmax logit matches the label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Restore the exponential-moving-average shadow values in place of
        # the raw weights, via the variable-renaming map.
        variable_averages = tf.train.ExponentialMovingAverage(
            train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step count is the suffix of the checkpoint file
                    # name, e.g. 'model/mnist.ckpt-29001'.
                    global_step = ckpt.model_checkpoint_path.split('/')[
                        -1].split('-')[-1]
                    accuracy_score = sess.run(
                        accuracy, feed_dict=validate_feed)
                    # Fix: corrected 'validattion' typo in the output.
                    print(
                        'After %s training steps, validation accuracy = %g' %
                        (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)


def main(argv=None):
    """Load the MNIST data and start the evaluation loop."""
    mnist = input_data.read_data_sets('../data/', one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    tf.app.run()

运行结果


train.py 训练模型的结果如下

$ python train.pyExtracting ../data/train-images-idx3-ubyte.gzExtracting ../data/train-labels-idx1-ubyte.gzExtracting ../data/t10k-images-idx3-ubyte.gzExtracting ../data/t10k-labels-idx1-ubyte.gzAfter 1 training steps, loss is 2.75381.After 1001 training steps, loss is 0.26364.After 2001 training steps, loss is 0.160792.After 3001 training steps, loss is 0.144208.After 4001 training steps, loss is 0.120926.After 5001 training steps, loss is 0.10708.After 6001 training steps, loss is 0.102106.......After 22001 training steps, loss is 0.0399828.After 23001 training steps, loss is 0.0408827.After 24001 training steps, loss is 0.0355409.After 25001 training steps, loss is 0.0378072.After 26001 training steps, loss is 0.0352473.After 27001 training steps, loss is 0.0357247.After 28001 training steps, loss is 0.0318179.After 29001 training steps, loss is 0.0417907.

eval.py 评估模型的结果如下

$ python eval.pyExtracting ../data/train-images-idx3-ubyte.gzExtracting ../data/train-labels-idx1-ubyte.gzExtracting ../data/t10k-images-idx3-ubyte.gzExtracting ../data/t10k-labels-idx1-ubyte.gzAfter 26001 training steps, validattion accuracy = 0.983After 28001 training steps, validattion accuracy = 0.985After 29001 training steps, validattion accuracy = 0.986......
原创粉丝点击