ResNet Study Notes (Part 2)


Part two breaks down resnet_train.py, the script that trains the ResNet model. It defines the training hyperparameters as command-line flags, a top-k error metric, and a train() function that builds the optimizer and runs the training loop.

from __future__ import print_function

import os
import sys
import time

import numpy as np
import tensorflow as tf

from resnet import *  # brings in loss(), MOVING_AVERAGE_DECAY, UPDATE_OPS_COLLECTION, ...

MOMENTUM = 0.9  # momentum coefficient for the optimizer

FLAGS = tf.app.flags.FLAGS  # command-line flags for the training hyperparameters
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 10, "batch size")
tf.app.flags.DEFINE_integer('max_steps', 500, "max steps")
tf.app.flags.DEFINE_boolean('resume', False,
                            'resume from latest saved state')
tf.app.flags.DEFINE_boolean('minimal_summaries', True,
                            'produce fewer summaries to save HD space')


# top_k_error checks whether each sample's true label is among the top-k
# predictions and returns the error rate (not the accuracy).
def top_k_error(predictions, labels, k):
    batch_size = float(FLAGS.batch_size)  # tf.shape(predictions)[0]
    in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
    num_correct = tf.reduce_sum(in_top_k)  # number of correctly classified samples
    return (batch_size - num_correct) / batch_size  # fraction misclassified


# the training function
def train(is_training, logits, images, labels):
    # non-trainable step counters for training and validation
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    val_step = tf.get_variable('val_step', [],
                               initializer=tf.constant_initializer(0),
                               trainable=False)

    loss_ = loss(logits, labels)  # loss() is defined in resnet.py
    predictions = tf.nn.softmax(logits)  # class probabilities after softmax

    top1_error = top_k_error(predictions, labels, 1)

    # loss_avg: smooth the raw loss with an exponential moving average;
    # the EMA update op joins the same collection as the batch-norm
    # updates, so it runs as part of every training step
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    tf.summary.scalar('loss_avg', ema.average(loss_))

    # validation stats: running val_op bumps val_step and updates a
    # second EMA that smooths the top-1 error
    ema = tf.train.ExponentialMovingAverage(0.9, val_step)
    val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
    top1_error_avg = ema.average(top1_error)  # smoothed top-1 error
    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)  # SGD with momentum
    grads = opt.compute_gradients(loss_)  # gradients of the loss w.r.t. all trainable variables
    # optionally record a histogram summary per gradient
    for grad, var in grads:
        if grad is not None and not FLAGS.minimal_summaries:
            tf.summary.histogram(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    if not FLAGS.minimal_summaries:
        # Display the training images in the visualizer.
        tf.summary.image('images', images)

        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

    # one training step = apply the gradients and run the collected
    # batch-norm (and loss-EMA) update ops
    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.global_variables())

    summary_op = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)  # initialize all variables
    tf.train.start_queue_runners(sess=sess)  # start the input-queue threads

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    # optionally resume from the latest checkpoint in train_dir
    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            print("No checkpoint to continue from in", FLAGS.train_dir)
            sys.exit(1)
        print("resume", latest)
        saver.restore(sess, latest)

    for _ in range(FLAGS.max_steps + 1):
        start_time = time.time()

        step = sess.run(global_step)  # current global step
        fetches = [train_op, loss_]

        write_summary = step % 100 == 0 and step > 1  # write summaries every 100 steps
        if write_summary:
            fetches.append(summary_op)

        outputs = sess.run(fetches, {is_training: True})  # run one training step

        loss_value = outputs[1]

        duration = time.time() - start_time  # wall-clock time for this step

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)  # throughput, examples per second
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            summary_str = outputs[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 10 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)

        # Run validation periodically: val_op updates the error EMA,
        # top1_error is the raw error on one validation batch
        if step > 1 and step % 10 == 0:
            _, top1_error_value = sess.run([val_op, top1_error], {is_training: False})
            print('Validation top1 error %.2f' % top1_error_value)
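A detail that is easy to misread: top_k_error returns the error rate, not the accuracy. It counts the correctly classified samples and reports the complement. Below is a minimal NumPy sketch of the same computation, with made-up scores (this is an illustration, not part of resnet_train.py):

import numpy as np

def top_k_error_np(predictions, labels, k):
    # a sample is correct if its true label is among the k highest-scoring classes
    batch_size = predictions.shape[0]
    top_k = np.argsort(predictions, axis=1)[:, -k:]  # indices of the k largest scores per row
    num_correct = sum(labels[i] in top_k[i] for i in range(batch_size))
    return (batch_size - num_correct) / float(batch_size)  # error rate, not accuracy

# three samples, four classes (made-up scores)
predictions = np.array([[0.10, 0.60, 0.20, 0.10],   # highest score: class 1
                        [0.30, 0.40, 0.20, 0.10],   # highest score: class 1
                        [0.05, 0.05, 0.10, 0.80]])  # highest score: class 3
labels = np.array([1, 2, 3])                        # second sample is misclassified
print(top_k_error_np(predictions, labels, 1))       # 0.333...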
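Both tf.train.ExponentialMovingAverage instances above apply the same smoothing rule: shadow = decay * shadow + (1 - decay) * value. (When a step counter such as global_step or val_step is passed in, TensorFlow also caps the effective decay at (1 + steps) / (10 + steps), so the average warms up faster early in training.) A plain-Python sketch of the basic update, with made-up numbers:

def ema_update(shadow, value, decay=0.9):
    # one smoothing step: history weighted by decay, the new value by (1 - decay)
    return decay * shadow + (1 - decay) * value

errors = [0.9, 0.8, 0.85, 0.6, 0.5]  # made-up per-batch top-1 errors
shadow = errors[0]                   # seed the average with the first value
for e in errors[1:]:
    shadow = ema_update(shadow, e)
print(shadow)  # ~0.82: smoother than the raw, noisy sequence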
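Note that resnet_train.py only defines train(); the logits, the input batches, and the is_training placeholder must be built by a driver script (in this codebase they come from resnet.py, which is star-imported). Here is a rough, hypothetical sketch of such a driver. inference() and its arguments are assumptions about resnet.py, and the random tensors are only a stand-in for a real queue-based input pipeline:

import tensorflow as tf

from resnet import inference          # assumed model-building function in resnet.py
from resnet_train import train, FLAGS

# boolean placeholder that train() feeds: True for training, False for validation
is_training = tf.placeholder('bool', [], name='is_training')

# stand-in input pipeline: random images and labels so the graph runs end to end;
# a real driver would read and preprocess actual data here
images = tf.random_uniform([FLAGS.batch_size, 224, 224, 3])
labels = tf.random_uniform([FLAGS.batch_size], maxval=1000, dtype=tf.int32)

logits = inference(images, is_training=is_training)  # hypothetical signature
train(is_training, logits, images, labels)           # runs the loop defined above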