CIFAR10 Code Analysis: cifar10_train.py
Source: Internet · Editor: 程序博客网 · Date: 2024/06/05 07:45
This post is an initial outline; it will be filled in gradually after the opening.
First, import the required libraries and define the command-line flags:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")
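The flag definitions above behave like ordinary command-line arguments with defaults. As a rough analogue (this is an illustrative sketch using Python's standard argparse, not TensorFlow's actual flags implementation), the same four options could be declared like this:

```python
# Illustrative analogue of tf.app.flags using argparse (assumption: this is
# a sketch of the behavior, not how TensorFlow implements flags internally).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train',
                    help='Directory where to write event logs and checkpoint.')
parser.add_argument('--max_steps', type=int, default=1000000,
                    help='Number of batches to run.')
parser.add_argument('--log_device_placement', action='store_true',
                    help='Whether to log device placement.')
parser.add_argument('--log_frequency', type=int, default=10,
                    help='How often to log results to the console.')

# Parsing an empty argument list yields the defaults, just as running
# cifar10_train.py with no flags would.
FLAGS = parser.parse_args([])
```

Running the real script, a flag is overridden the same way, e.g. `python cifar10_train.py --max_steps=5000`.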
Below is the main body of the training function:
def train():
  """Train CIFAR-10 for a number of steps."""
  # Create a new graph and make it the default for this block
  # (see the tf.Graph documentation for details).
  with tf.Graph().as_default():
    # Get the global step counter for this graph, creating it if needed.
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # _LoggerHook periodically prints the loss and throughput to the console.
    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    # The stop criterion is supplied by StopAtStepHook in the session
    # setup that follows.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()
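To make the hook lifecycle concrete, here is a minimal plain-Python mock of the control flow that MonitoredTrainingSession drives: begin() once at session creation, then before_run / after_run around every step, with a stop hook able to end the loop. All names here (run_training, the simplified StopAtStepHook) are illustrative assumptions, not TensorFlow APIs.

```python
# Plain-Python sketch of the MonitoredTrainingSession hook lifecycle
# (assumption: simplified for illustration; the real TF hooks receive
# run_context/run_values arguments and request stop through run_context).
class StopAtStepHook:
    """Requests a stop once `last_step` steps have run."""

    def __init__(self, last_step):
        self.last_step = last_step
        self._stop = False

    def begin(self):
        self._step = -1  # mirrors _LoggerHook's initialization

    def before_run(self):
        self._step += 1

    def after_run(self):
        if self._step + 1 >= self.last_step:
            self._stop = True

    def should_stop(self):
        return self._stop


def run_training(train_op, hooks):
    """Mimics `while not mon_sess.should_stop(): mon_sess.run(train_op)`."""
    for h in hooks:
        h.begin()
    steps = 0
    while not any(h.should_stop() for h in hooks):
        for h in hooks:
            h.before_run()
        train_op()  # one training step, i.e. one mon_sess.run(train_op)
        for h in hooks:
            h.after_run()
        steps += 1
    return steps
```

For example, `run_training(lambda: None, [StopAtStepHook(last_step=5)])` runs exactly 5 steps before the stop hook fires, which is the role StopAtStepHook(last_step=FLAGS.max_steps) plays in the real code.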
This part of the code is fairly straightforward. In the main function train(), cifar10.distorted_inputs() first reads the images and labels, cifar10.inference() computes the logits, and cifar10.loss() computes the loss; then train_op = cifar10.train() builds the op that updates the model parameters, and the session loop runs train_op until the stop criterion is met (StopAtStepHook at max_steps, with NanTensorHook aborting on a NaN loss). See the corresponding articles for the functions called here.
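As a worked example of the throughput numbers that _LoggerHook prints: it measures the wall-clock time since the last log line and divides the work done in that window by it. The batch size of 128 below is an assumed value (the cifar10.py default flag), used purely for illustration.

```python
# Worked example of _LoggerHook's throughput math.
log_frequency = 10   # steps between log lines (the flag's default)
batch_size = 128     # assumption: cifar10.py's default batch_size flag
duration = 2.0       # seconds elapsed since the previous log line

# log_frequency steps ran in `duration` seconds, each on batch_size images:
examples_per_sec = log_frequency * batch_size / duration
sec_per_batch = duration / log_frequency

print(examples_per_sec)  # 640.0 examples/sec
print(sec_per_batch)     # 0.2 sec/batch
```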