tensorflow教程学习三TensorFlow运作方式入门
来源:互联网 发布:被冒名网络贷款 编辑:程序博客网 时间:2024/05/29 03:24
讲解链接:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html
"""Trains and Evaluates the MNIST network using a feed dictionary."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_function# pylint: disable=missing-docstringimport argparseimport osimport sysimport timefrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datafrom tensorflow.examples.tutorials.mnist import mnistFLAGS = None#placeholder_inputs()函数将生成两个tf.placeholder操作,定义传入图表中的shape参数,#shape参数中包括batch_size值,后续还会将实际的训练用例传入图表。#在训练循环(training loop)的后续步骤中,传入的整个图像和标签数据集会被切片,#以符合每一个操作所设置的batch_size值,占位符操作将会填补以符合这个batch_size值。#然后使用feed_dict参数,将数据传入sess.run()函数。def placeholder_inputs(batch_size): images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,mnist.IMAGE_PIXELS)) labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) return images_placeholder, labels_placeholderdef fill_feed_dict(data_set, images_pl, labels_pl): images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,FLAGS.fake_data) feed_dict = { images_pl: images_feed, labels_pl: labels_feed, } return feed_dictdef do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_set): # And run one epoch of eval. true_count = 0 # Counts the number of correct predictions. steps_per_epoch = data_set.num_examples // FLAGS.batch_size num_examples = steps_per_epoch * FLAGS.batch_size for step in xrange(steps_per_epoch): feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder) true_count += sess.run(eval_correct, feed_dict=feed_dict) precision = float(true_count) / num_examples print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % (num_examples, true_count, precision))def run_training():#在run_training()方法的一开始,input_data.read_data_sets()函数会确保你的本地训练文件夹中,#已经下载了正确的数据,然后将这些数据解压并返回一个含有DataSet实例的字典。 data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data) # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Generate placeholders for the images and labels. images_placeholder, labels_placeholder = placeholder_inputs( FLAGS.batch_size) # Build a Graph that computes predictions from the inference model. logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2) # Add to the Graph the Ops for loss calculation. loss = mnist.loss(logits, labels_placeholder) # Add to the Graph the Ops that calculate and apply gradients. train_op = mnist.training(loss, FLAGS.learning_rate) # Add the Op to compare the logits to the labels during evaluation. eval_correct = mnist.evaluation(logits, labels_placeholder) # Build the summary Tensor based on the TF collection of Summaries. summary = tf.summary.merge_all() # Add the variable initializer Op. init = tf.global_variables_initializer() # Create a saver for writing training checkpoints. saver = tf.train.Saver() sess = tf.Session() summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph) sess.run(init) # Start the training loop. for step in xrange(FLAGS.max_steps): start_time = time.time() feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder) _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict) duration = time.time() - start_time # Write the summaries and print an overview fairly often. if step % 100 == 0: # Print status to stdout. print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) # Update the events file. summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) summary_writer.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt') saver.save(sess, checkpoint_file, global_step=step) # Evaluate against the training set. print('Training Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.train) # Evaluate against the validation set. print('Validation Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.validation) # Evaluate against the test set. print('Test Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_sets.test)def main(_): if tf.gfile.Exists(FLAGS.log_dir): tf.gfile.DeleteRecursively(FLAGS.log_dir) tf.gfile.MakeDirs(FLAGS.log_dir) run_training()if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--learning_rate', type=float, default=0.01, help='Initial learning rate.' ) parser.add_argument( '--max_steps', type=int, default=2000, help='Number of steps to run trainer.' ) parser.add_argument( '--hidden1', type=int, default=128, help='Number of units in hidden layer 1.' ) parser.add_argument( '--hidden2', type=int, default=32, help='Number of units in hidden layer 2.' ) parser.add_argument( '--batch_size', type=int, default=100, help='Batch size. Must divide evenly into the dataset sizes.' ) parser.add_argument( '--input_data_dir', type=str, default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/input_data'), help='Directory to put the input data.' ) parser.add_argument( '--log_dir', type=str, default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/logs/fully_connected_feed'), help='Directory to put the log data.' ) parser.add_argument( '--fake_data', default=False, help='If true, uses fake data for unit testing.', action='store_true' ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
附一张运行结果图:
构建图表 (Build the Graph)
在为数据创建占位符之后,就可以运行mnist.py文件,经过三阶段的模式函数操作:inference(), loss(),和training()。图表就构建完成了。
1.inference() —— 尽可能地构建好图表,满足促使神经网络向前反馈并做出预测的要求。
2.loss() —— 往inference图表中添加生成损失(loss)所需要的操作(ops)。
3.training() —— 往损失图表中添加计算并应用梯度(gradients)所需的操作。
训练模型
一旦图表构建完毕,就通过fully_connected_feed.py文件中的用户代码进行循环地迭代式训练和评估。
图表
在run_training()这个函数的一开始,是一个Python语言中的with命令,这个命令表明所有已经构建的操作都要与默认的tf.Graph全局实例关联起来。
会话
完成全部的构建准备、生成全部所需的操作之后,我们就可以创建一个tf.Session,用于运行图表。
训练循环
完成会话中变量的初始化之后,就可以开始训练了。
向图表提供反馈
执行每一步时,我们的代码会生成一个反馈字典(feed dictionary),其中包含对应步骤中训练所要使用的例子,这些例子的哈希键就是其所代表的占位符操作。
检查状态
在运行sess.run函数时,要在代码中明确其需要获取的两个值:[train_op, loss]。
状态可视化
为了释放TensorBoard所使用的事件文件(events file),所有的即时数据(在这里只有一个)都要在图表构建阶段合并至一个操作(op)中。
评估模型
每隔一千个训练步骤,我们的代码会尝试使用训练数据集与测试数据集,对模型进行评估。do_eval函数会被调用三次,分别使用训练数据集、验证数据集合测试数据集。
构建评估图表(Eval Graph)
在打开默认图表(Graph)之前,我们应该先调用get_data(train=False)函数,抓取测试数据集。
评估图表的输出(Eval Output)
之后,我们可以创建一个循环,往其中添加feed_dict,并在调用sess.run()函数时传入eval_correct操作,目的就是用给定的数据集评估模型。
- tensorflow教程学习三TensorFlow运作方式入门
- Tensorflow教程-TensorFlow运作方式入门
- TensorFlow运作方式入门
- TensorFlow运作方式入门
- tensorflow运作方式入门
- TensorFlow官方文档学习|TensorFlow运作方式入门
- Tensorflow 学习笔记 (2)官方文档学习 tensorflow运作方式入门
- TensorFlow学习笔记(四)——TensorFlow运作方式入门、可视化
- Tensorflow运作方式-综述
- TensorsFlow学习笔记5----TensorFlow Mechanics 101基本运作方式
- Tensorflow教程-MNIST机器学习入门
- tensorflow教程学习三深入MNIST
- TensorFlow 教程入门
- tensorflow的基本运作
- Tensorflow运作之变量
- TensorFlow入门很好的教程:你好,TensorFlow!
- TensorFlow入门很好的教程:你好,TensorFlow!
- TensorFlow官方教程学习笔记(三)——MNIST入门(续)
- 关于xcode8.0以上项目运行在低版本The document “Main.storyboard” requires Xcode 8.0 or later.
- EL表达式小总
- 音乐类Demo资源大全
- 竖排文本控件
- 图结构练习——最短路径(Dijkstra算法)
- tensorflow教程学习三TensorFlow运作方式入门
- css 元素水平居中
- 新建String对象小知识点
- Android之FFmpeg(3)--添加为视频添加背景音乐
- Cocos-JS 音效
- XYNU OJ 1103: 例题6-5 求矩阵最大值
- hdu1598
- 自动化无人零售店监管系统开发
- Python学习笔记