Tensorflow 03: 前向神经网络-MNIST

来源:互联网 发布:机械加工工艺编程员 编辑:程序博客网 时间:2024/06/05 11:17

前言

今天主要从整体上学习了一下tensorflow的大致框架及写代码的一些标准规范。主要包括:
(1)前向神经网络的设计,并用于MNIST手写体数字的识别。网络结构如下:
这里写图片描述
(2)网络的模块化设计,从计算图的构造到训练;
(3)Tensorflow中log文件的写法,并用tensorboard进行可视化;
(4)Tensorflow中训练好的模型的保存;

代码

主程序如下所示,文件:fully_connected_feed.py。程序中的一些重要代码都已经加入中文注释,方便以后自己阅读查看。

# coding=utf-8"""使用前向神经网络训练 MNIST,并用softmax做分类器"""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_function# pylint: disable=missing-docstringimport argparseimport os.pathimport sysimport timefrom six.moves import xrange  # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport mnist# Basic model parameters as external flags.# Basic model parameters as external flags.FLAGS = Nonedef placeholder_inputs(batch_size):    images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))    return images_placeholder, labels_placeholderdef fill_feed_dict(data_set, images_pl, labels_pl):    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)    feed_dict = {images_pl: images_feed, labels_pl: labels_feed, }    return feed_dictdef do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set):    """    模型评估    """    # And run one epoch of eval.    true_count = 0  # Counts the number of correct predictions.    steps_per_epoch = data_set.num_examples // FLAGS.batch_size    num_examples = steps_per_epoch * FLAGS.batch_size    for step in xrange(steps_per_epoch):        feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder)        true_count += sess.run(eval_correct, feed_dict=feed_dict)    precision = float(true_count) / num_examples    print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' % (num_examples, true_count, precision))def run_training():    # 读取 MNIST 数据    data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)    # 在 Tensorflow 默认图中创建模型    with tf.Graph().as_default():        images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)        # 构建前向计算的图模型.        
logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)        # 往Graph中添加loss function        loss = mnist.loss(logits, labels_placeholder)        # 往Graph中添加训练train_op ops        train_op = mnist.training(loss, FLAGS.learning_rate)        eval_correct = mnist.evaluation(logits, labels_placeholder)        # 合并默认图中的所有summary,并返回一个summary tensor        summary = tf.summary.merge_all()        # 创建模型保存对象 saver        saver = tf.train.Saver()        sess = tf.Session()        # 实例化一个 summary_writer 来输出summary和图模型        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)        init = tf.global_variables_initializer()        sess.run(init)        # loop-训练        for step in xrange(FLAGS.max_steps):            start_time = time.time()            feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder)            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)            duration = time.time() - start_time            # Write the summaries and print an overview fairly often.            if step % 100 == 0:                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))                # Update the events file.                
summary_str = sess.run(summary, feed_dict=feed_dict)                summary_writer.add_summary(summary_str, step)                # 写入磁盘                summary_writer.flush()            # 保存训练的模型文件 ckpt            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')                saver.save(sess, checkpoint_file, global_step=step)                # 对训练好的模型进行评估-训练集                print('Training Data Eval:')                do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.train)                # 对训练好的模型进行评估-验证集                print('Validation Data Eval:')                do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.validation)                # 对训练好的模型进行评估-测试集                print('Test Data Eval:')                do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.test)def main(_):    # 判断路径是否存在    if tf.gfile.Exists(FLAGS.log_dir):        # 删除路径下的所有文件        tf.gfile.DeleteRecursively(FLAGS.log_dir)    # 创建路径    tf.gfile.MakeDirs(FLAGS.log_dir)    run_training()if __name__ == '__main__':    parser = argparse.ArgumentParser()    parser.add_argument(        '--learning_rate',        type=float,        default=0.01,        help='Initial learning rate.'    )    parser.add_argument(        '--max_steps',        type=int,        default=2000,        help='Number of steps to run trainer.'    )    parser.add_argument(        '--hidden1',        type=int,        default=128,        help='Number of units in hidden layer 1.'    )    parser.add_argument(        '--hidden2',        type=int,        default=32,        help='Number of units in hidden layer 2.'    )    parser.add_argument(        '--batch_size',        type=int,        default=100,        help='Batch size.  Must divide evenly into the dataset sizes.'    
)    parser.add_argument(        '--input_data_dir',        type=str,        default='MNIST_data',        help='Directory to put the input data.'    )    parser.add_argument(        '--log_dir',        type=str,        default='MNIST_Log',        help='Directory to put the log data.'    )    parser.add_argument(        '--fake_data',        default=False,        help='If true, uses fake data for unit testing.',        action='store_true'    )    FLAGS, unparsed = parser.parse_known_args()    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

上述主程序用到的模块文件mnist.py如下。重要功能也加入了中文注释:

# coding=utf-8"""构建 MNIST 网络,包括:(1)inference(2)loss(3)training"""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport mathimport tensorflow as tf# 参数NUM_CLASSES = 10IMAGE_SIZE = 28IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZEdef inference(images, hidden1_units, hidden2_units):    """    构建网络前向计算图    参数:        images: 输入的 images, placeholder.        hidden1_units: 隐层1神经元大小.        hidden2_units: 隐层2神经元大小.    返回值:        softmax_linear: logits.    """    # 隐层1    # tf.name_scope()函数返回一个context manager    with tf.name_scope('hidden1'):        weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],                                stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights')        biases = tf.Variable(tf.zeros([hidden1_units]), name='biases')        hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)    # 隐层2    with tf.name_scope('hidden2'):        weights = tf.Variable(tf.truncated_normal([hidden1_units, hidden2_units],                                stddev=1.0 / math.sqrt(float(hidden1_units))), name='weights')        biases = tf.Variable(tf.zeros([hidden2_units]), name='biases')        hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)    # 输出    with tf.name_scope('softmax_linear'):        weights = tf.Variable(tf.truncated_normal([hidden2_units, NUM_CLASSES],                                stddev=1.0 / math.sqrt(float(hidden2_units))), name='weights')        biases = tf.Variable(tf.zeros([NUM_CLASSES]), name='biases')        logits = tf.matmul(hidden2, weights) + biases    return logitsdef loss(logits, labels):    """    计算网络的loss    参数:        logits: softmax_linear        labels: 真实的标签    返回值:        loss: 代价函数    """    labels = tf.to_int64(labels)    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='xentropy')    return tf.reduce_mean(cross_entropy, name='xentropy_mean')def training(loss, learning_rate):    """  
  构建训练 ops    """    # Add a scalar summary for the snapshot loss.    tf.summary.scalar('loss', loss)    optimizer = tf.train.GradientDescentOptimizer(learning_rate)    # Create a variable to track the global step.    global_step = tf.Variable(0, name='global_step', trainable=False)    train_op = optimizer.minimize(loss, global_step=global_step)    return train_opdef evaluation(logits, labels):    """    评估网络计算出来的logits的准确率    """    # For a classifier model, we can use the in_top_k Op.    # It returns a bool tensor with shape [batch_size] that is true for    # the examples where the label is in the top k (here k=1)    # of all logits for that example.    correct = tf.nn.in_top_k(logits, labels, 1)    # Return the number of true entries.    return tf.reduce_sum(tf.cast(correct, tf.int32))

重要模块介绍

(1)模型的保存。包括2步:
a.saver对象的创建: saver = tf.train.Saver()
b.调用save函数保存模型: saver.save(sess, checkpoint_file, global_step=step)

(2)summary的写法。用于tensorboard的可视化。
a.某种类型的summary的创建,如标量summary: tf.summary.scalar('loss', loss)
b.合并计算图中用到的所有summary:  summary = tf.summary.merge_all()
c.实例化一个 summary_writer 来输出summary和图模型: summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
d.运行summary,并将结果写入到磁盘:
summary_str = sess.run(summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)
summary_writer.flush()

参考网址

https://www.tensorflow.org/get_started/mnist/mechanics  --- tensorflow官网教程

原创粉丝点击