RESNET学习笔记(二)
来源:互联网 发布:什么手机支持双频网络 编辑:程序博客网 时间:2024/05/22 07:49
第二部分解读的是训练ResNet的脚本resnet_train.py,该脚本定义并执行ResNet的训练流程。
# resnet_train.py -- training loop for the ResNet model defined in resnet.py.
# `loss`, `MOVING_AVERAGE_DECAY` and `UPDATE_OPS_COLLECTION` come from resnet.
from resnet import *

import os
import sys
import time

import numpy as np
import tensorflow as tf

MOMENTUM = 0.9  # momentum coefficient for the MomentumOptimizer

FLAGS = tf.app.flags.FLAGS

# Command-line flags controlling the training run.
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 10, "batch size")
tf.app.flags.DEFINE_integer('max_steps', 500, "max steps")
tf.app.flags.DEFINE_boolean('resume', False, 'resume from latest saved state')
tf.app.flags.DEFINE_boolean('minimal_summaries', True,
                            'produce fewer summaries to save HD space')


def top_k_error(predictions, labels, k):
    """Return the top-k error rate for one batch.

    Args:
        predictions: [batch_size, num_classes] tensor of class scores.
        labels: [batch_size] tensor of true class indices.
        k: the true label must be among the k highest-scoring classes
            to count as correct.

    Returns:
        Scalar tensor: fraction of the batch classified incorrectly
        (this is the error rate, not the accuracy).
    """
    batch_size = float(FLAGS.batch_size)  # tf.shape(predictions)[0]
    # BUG FIX: the original hard-coded k=1 here, silently ignoring the
    # `k` argument. The only caller in this file passes k=1, so behavior
    # is unchanged for existing callers.
    in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
    num_correct = tf.reduce_sum(in_top_k)
    return (batch_size - num_correct) / batch_size


def train(is_training, logits, images, labels):
    """Build the training graph and run the optimization loop.

    Args:
        is_training: boolean placeholder fed True for train steps and
            False for validation steps (switches batch-norm behavior).
        logits: [batch_size, num_classes] model output tensor.
        images: input image batch tensor (only used for summaries).
        labels: [batch_size] tensor of true class indices.

    Side effects: writes summaries and periodic checkpoints under
    FLAGS.train_dir; exits the process if --resume finds no checkpoint.
    """
    # Non-trainable counters: global_step drives LR/EMA schedules and
    # checkpoint numbering; val_step drives the validation-error EMA.
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    val_step = tf.get_variable('val_step', [],
                               initializer=tf.constant_initializer(0),
                               trainable=False)

    loss_ = loss(logits, labels)
    predictions = tf.nn.softmax(logits)

    top1_error = top_k_error(predictions, labels, 1)

    # loss_avg: exponential moving average of the training loss, updated
    # as part of the batch-norm update collection.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    tf.summary.scalar('loss_avg', ema.average(loss_))

    # validation stats: running val_op bumps val_step and folds the current
    # top-1 error into a second moving average.
    ema = tf.train.ExponentialMovingAverage(0.9, val_step)
    val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
    top1_error_avg = ema.average(top1_error)
    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    tf.summary.scalar('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None and not FLAGS.minimal_summaries:
            # CHANGED: tf.histogram_summary -> tf.summary.histogram, matching
            # the tf.summary.* API the rest of this file already uses.
            tf.summary.histogram(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    if not FLAGS.minimal_summaries:
        # Display the training images in the visualizer.
        # CHANGED: tf.image_summary -> tf.summary.image (same API migration).
        tf.summary.image('images', images)
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

    # One train step = gradient application + all queued batch-norm
    # (and loss-EMA) updates.
    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    # CHANGED: tf.initialize_all_variables() is deprecated in favor of
    # tf.global_variables_initializer(), consistent with the
    # tf.global_variables() call above.
    init = tf.global_variables_initializer()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)  # start input-pipeline threads

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            # CHANGED: Python-2 print statements -> print() calls.
            print("No checkpoint to continue from in", FLAGS.train_dir)
            sys.exit(1)
        print("resume", latest)
        saver.restore(sess, latest)

    # CHANGED: xrange -> range (Python 3 compatibility).
    for x in range(FLAGS.max_steps + 1):
        start_time = time.time()

        step = sess.run(global_step)
        i = [train_op, loss_]

        # BUG FIX: the original `step % 100 and step > 1` was truthy on
        # every step NOT divisible by 100, writing a summary almost every
        # iteration; the intent (and the minimal_summaries flag) is to
        # write one every 100 steps.
        write_summary = step > 1 and step % 100 == 0
        if write_summary:
            i.append(summary_op)

        o = sess.run(i, {is_training: True})

        loss_value = o[1]
        duration = time.time() - start_time  # wall-clock time of this step

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            summary_str = o[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 10 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)

        # Run validation periodically (feeds is_training=False so
        # batch-norm uses its moving statistics).
        if step > 1 and step % 10 == 0:
            _, top1_error_value = sess.run([val_op, top1_error],
                                           {is_training: False})
            print('Validation top1 error %.2f' % top1_error_value)
阅读全文
0 0
- RESNET学习笔记(二)
- Resnet学习笔记(一)--resnet.py
- Resnet学习笔记(三)--image_processing.py
- Resnet学习笔记(四)--train_imagenet.py
- ResNet学习笔记
- 学习笔记TF033:实现ResNet
- caffe学习笔记27-ResNet论文笔记
- 学习笔记:inception V4 与resnet
- ResNet学习
- 系统学习深度学习(二十)--ResNet,DenseNet,以及残差家族
- 系统学习深度学习(二十一)--GoogLeNetV4与Inception-ResNet V1,V2
- 【深度学习】入门理解ResNet和他的小姨子们(二)---DenseNet
- 深度学习笔记(5)——学术界的霸主Resnet
- 菜鸡的学习笔记(一):DeepLab-ResNet Model代码中的相关知识点
- (DeepLab-resnet) + 深度学习部份层 小笔记。
- ResNet论文笔记
- resNet论文笔记
- ResNet论文笔记
- unity3d 获取使用内存大小 android and ios
- 聚合支付”为什么很多游戏商家选择他
- KEIL5,STM32工程建立注意事项
- Wake Your Computer Up From Local Network
- Tomcat 部署项目的三种方法
- RESNET学习笔记(二)
- okhttp封装类
- 数据可视化D3-简单说
- 漫谈 Clustering (5): Hierarchical Clustering
- 关于焦点
- bzoj3450 Tyvj1952 Easy
- python学习(一)
- SpringApplication 的运行过程分析: run()
- vue2.0动态改变index中title