CIFAR10 代码分析详解——cifar10_train.py

来源:互联网 发布:linux给其他用户权限 编辑:程序博客网 时间:2024/06/05 07:45

先在这里种个草,开篇后慢慢补完

引入各种库,并定义参数

# Module header for cifar10_train.py: imports and command-line flags.
# (The scraped original had all of these statements fused onto one line,
# which is not valid Python; they are restored one statement per line.)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

# Shared flag registry (TF 1.x tf.app.flags). Flags used below but not
# defined here (e.g. FLAGS.batch_size) are presumably registered by the
# cifar10 module -- confirm in cifar10.py.
FLAGS = tf.app.flags.FLAGS

# Where MonitoredTrainingSession writes checkpoints and event logs.
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
# Upper bound on training steps; used as StopAtStepHook's last_step.
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
# Forwarded to tf.ConfigProto(log_device_placement=...).
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
# The logging hook prints a status line every `log_frequency` steps.
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")

下面是训练函数主体

def throughput_stats(num_steps, batch_size, duration):
  """Return ``(examples_per_sec, sec_per_batch)`` for a timed run of steps.

  Args:
    num_steps: number of training steps that ran during ``duration``.
    batch_size: number of examples processed per step.
    duration: wall-clock seconds spanned by those steps.

  Returns:
    A tuple ``(examples_per_sec, sec_per_batch)``.
  """
  examples_per_sec = num_steps * batch_size / duration
  sec_per_batch = float(duration / num_steps)
  return examples_per_sec, sec_per_batch


def train():
  """Train CIFAR-10 for a number of steps.

  Builds the input pipeline, model, loss and training op in a fresh graph,
  then repeatedly runs ``train_op`` inside a
  ``tf.train.MonitoredTrainingSession`` that uses ``FLAGS.train_dir`` as its
  checkpoint directory, until one of its hooks requests a stop.
  """
  # Build everything in a new graph instead of whatever graph happens to be
  # the default when train() is called.
  with tf.Graph().as_default():
    # Step counter for training. Called with no graph argument, this creates
    # (or fetches) the global step of the current default graph, i.e. the
    # graph opened above. StopAtStepHook below compares it against
    # FLAGS.max_steps; cifar10.train() presumably increments it on every
    # update -- confirm in cifar10.py.
    # (The original also called get_or_create_global_step(Graph), which
    # referenced an undefined name `Graph` and raised NameError.)
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime every FLAGS.log_frequency steps."""

      def begin(self):
        # Local count of sess.run calls made through this hook.
        self._step = -1
        # Step at which the previous status line was printed.
        self._last_logged_step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency != 0:
          return
        current_time = time.time()
        duration = current_time - self._start_time
        self._start_time = current_time
        # At the first report (step 0) only one step has run, not
        # FLAGS.log_frequency of them, so count the steps actually executed
        # since the previous report instead of assuming a full interval.
        steps_run = self._step - self._last_logged_step
        self._last_logged_step = self._step

        loss_value = run_values.results
        # FLAGS.batch_size is not defined in this file; presumably
        # registered by the cifar10 module -- confirm.
        examples_per_sec, sec_per_batch = throughput_stats(
            steps_run, FLAGS.batch_size, duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), self._step, loss_value,
                            examples_per_sec, sec_per_batch))

    # Stop criterion: the loop below runs until a hook requests a stop.
    #   * StopAtStepHook stops once the global step reaches FLAGS.max_steps.
    #   * NanTensorHook watches `loss` and aborts training if it becomes NaN.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  """Prepare the data and an empty FLAGS.train_dir, then call train().

  Calls cifar10.maybe_download_and_extract() first (presumably fetches the
  dataset if it is not already present -- see cifar10.py).
  """
  cifar10.maybe_download_and_extract()
  # Every run starts from scratch: any existing contents of train_dir
  # (previous checkpoints and event logs) are deleted before training.
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()

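该部分代码比较简单:在主体函数 train() 中,先通过 cifar10.distorted_inputs() 读取图像和标签,然后通过 cifar10.inference() 计算 logits,通过 cifar10.loss() 计算损失,再创建 train_op = cifar10.train(loss, global_step) 来更新模型参数。训练循环由 tf.train.MonitoredTrainingSession 驱动,一直运行到满足 stop criterion 为止:StopAtStepHook 在 global_step 达到 FLAGS.max_steps 时停止训练,NanTensorHook 则在损失出现 NaN 时中止训练。调用的函数参见相应的文章。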


0 0
原创粉丝点击