py-faster-rcnn Source Code Walkthrough (Part 3): train.py
Source: Internet | Editor: 程序博客网 | Date: 2024/06/04 19:24
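The core of this file is a simple wrapper class around the Caffe solver, whose main purpose is to implement its own snapshot logic. There is not much in it worth singling out; it gets its own installment mainly so that readers can follow the training process coherently from start to finish.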
```python
import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os

from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format  # ensures pb2.text_format is loaded


class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.
    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            print 'Computing bounding-box regression targets...'
            self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
            print 'done'

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model)
            self.solver.net.copy_from(pretrained_model)

        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # All of the data preparation so far leads up to this one line: hand
        # the roidb to the input data layer. From here on we move into the
        # analysis of the training process itself.
        self.solver.net.layers[0].set_roidb(roidb)
```
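To make the normalization concrete: with BBOX_NORMALIZE_TARGETS_PRECOMPUTED, the means and stds are fixed configuration values rather than statistics gathered from the data, which is exactly what the RPN assertion above requires. A minimal numpy sketch, assuming the py-faster-rcnn defaults BBOX_NORMALIZE_MEANS = (0, 0, 0, 0) and BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2); precomputed_stats and normalize_targets are hypothetical helpers that only approximate what rdl_roidb.add_bbox_regression_targets does:

```python
import numpy as np

# Assumed defaults for the fixed per-coordinate statistics of (dx, dy, dw, dh)
BBOX_NORMALIZE_MEANS = np.array([0.0, 0.0, 0.0, 0.0])
BBOX_NORMALIZE_STDS = np.array([0.1, 0.1, 0.2, 0.2])


def precomputed_stats(num_classes):
    """Hypothetical helper: tile the fixed mean/std over all classes and
    flatten them to shape (num_classes * 4,)."""
    means = np.tile(BBOX_NORMALIZE_MEANS, (num_classes, 1))
    stds = np.tile(BBOX_NORMALIZE_STDS, (num_classes, 1))
    return means.ravel(), stds.ravel()


def normalize_targets(targets, means, stds):
    """Hypothetical helper: normalize (N, 4) regression targets of one class."""
    return (targets - means) / stds


means, stds = precomputed_stats(num_classes=21)
t = np.array([[0.05, -0.02, 0.1, 0.3]])
print(normalize_targets(t, means[:4], stds[:4]))
```

The flattened (num_classes * 4,) layout matches the way snapshot() broadcasts bbox_stds over the rows of the bbox_pred weights.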
snapshot
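snapshot is implemented here by hand rather than left to Caffe. Before saving, it temporarily un-normalizes the learned bbox_pred weights with the stored bbox_means and bbox_stds so that the saved model can be used directly at test time, then restores the original values so training can continue. Beyond that, there is little here that rewards a close reading.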
```python
    def snapshot(self):
        """Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.
        """
        net = self.solver.net

        scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                             cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                             net.params.has_key('bbox_pred'))

        if scale_bbox_params:
            # save original values
            orig_0 = net.params['bbox_pred'][0].data.copy()
            orig_1 = net.params['bbox_pred'][1].data.copy()

            # scale and shift with bbox reg unnormalization; then save snapshot
            net.params['bbox_pred'][0].data[...] = \
                (net.params['bbox_pred'][0].data *
                 self.bbox_stds[:, np.newaxis])
            net.params['bbox_pred'][1].data[...] = \
                (net.params['bbox_pred'][1].data *
                 self.bbox_stds + self.bbox_means)

        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
        filename = os.path.join(self.output_dir, filename)

        net.save(str(filename))
        print 'Wrote snapshot to: {:s}'.format(filename)

        if scale_bbox_params:
            # restore net to original state
            net.params['bbox_pred'][0].data[...] = orig_0
            net.params['bbox_pred'][1].data[...] = orig_1
        return filename
```
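Why the rescaling works: during training, bbox_pred learns to output normalized deltas, y_norm = W x + b, and the real deltas are y = y_norm * std + mean. Multiplying each output row of W by std and setting b' = b * std + mean folds that un-normalization into the layer, since W' x + b' = std * (W x + b) + mean. The numpy check below verifies the identity on random data; the sizes are illustrative assumptions, not the real network's dimensions:

```python
import numpy as np

rng = np.random.default_rng(0)

num_outputs, feat_dim = 8, 5                    # illustrative: 2 classes x 4 deltas
W = rng.normal(size=(num_outputs, feat_dim))    # plays the role of bbox_pred[0] (weights)
b = rng.normal(size=num_outputs)                # plays the role of bbox_pred[1] (bias)
means = rng.normal(size=num_outputs)
stds = rng.uniform(0.1, 0.3, size=num_outputs)

# Fold the un-normalization into the parameters, as snapshot() does
W_unnorm = W * stds[:, np.newaxis]
b_unnorm = b * stds + means

x = rng.normal(size=feat_dim)                   # one RoI feature vector
pred_norm = W @ x + b                           # output of the net during training
pred_direct = W_unnorm @ x + b_unnorm           # output of the saved snapshot

print(np.allclose(pred_norm * stds + means, pred_direct))  # True
```

Because snapshot() restores orig_0 and orig_1 after saving, training itself keeps regressing the normalized targets; only the file on disk carries the folded parameters.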
train_model
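The main training loop. It advances the solver one SGD step at a time, prints the average time per iteration every 10 × display iterations, writes a snapshot every cfg.TRAIN.SNAPSHOT_ITERS iterations, and writes one final snapshot after the loop unless the last iteration was already saved. It returns the list of snapshot paths.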
```python
    def train_model(self, max_iters):
        """Network training loop."""
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if self.solver.iter % (10 * self.solver_param.display) == 0:
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths
```
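The snapshot schedule can be checked without Caffe. In this sketch a plain counter stands in for solver.step(1) and solver.iter (an assumption; the real loop also times each step), and snapshot_schedule is a hypothetical helper that returns the iterations at which snapshot() would be called:

```python
def snapshot_schedule(max_iters, snapshot_iters, start_iter=0):
    """Return the iterations at which train_model() would call snapshot()."""
    taken = []
    last_snapshot_iter = -1
    it = start_iter
    while it < max_iters:
        it += 1  # stands in for self.solver.step(1) advancing solver.iter by one
        if it % snapshot_iters == 0:
            last_snapshot_iter = it
            taken.append(it)
    # Final snapshot, unless the last iteration was already saved
    if last_snapshot_iter != it:
        taken.append(it)
    return taken


print(snapshot_schedule(25, 10))  # [10, 20, 25]
print(snapshot_schedule(20, 10))  # [10, 20]
```

Either way, the last entry of model_paths is the model at max_iters, which is why callers can simply take model_paths[-1] as the final model.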
get_training_roidb
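If cfg.TRAIN.USE_FLIPPED is set, this function appends a horizontally flipped copy of every image in the roidb, doubling the training set and reducing the risk of overfitting. It then calls prepare_roidb for some preparatory work, which among other things fills in the per-entry max_overlaps field that filter_roidb relies on.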
```python
def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb
```
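Flipping an image horizontally also means mirroring its boxes: y stays the same and the x coordinates swap roles. A minimal numpy sketch, assuming boxes are (x1, y1, x2, y2) in 0-based inclusive pixel coordinates as in py-faster-rcnn's imdb.append_flipped_images; flip_boxes is a hypothetical helper name:

```python
import numpy as np


def flip_boxes(boxes, width):
    """Mirror (x1, y1, x2, y2) boxes across the vertical centre line of an
    image that is `width` pixels wide (0-based, inclusive coordinates)."""
    flipped = boxes.copy()
    flipped[:, 0] = width - boxes[:, 2] - 1  # new x1 comes from old x2
    flipped[:, 2] = width - boxes[:, 0] - 1  # new x2 comes from old x1
    assert (flipped[:, 2] >= flipped[:, 0]).all()
    return flipped


boxes = np.array([[10, 20, 30, 40]])
print(flip_boxes(boxes, width=100))  # [[69 20 89 40]]
```

Applying the flip twice returns the original boxes, a handy sanity check when extending this step with other augmentations.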
filter_roidb
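This function defines a helper, is_valid, that checks whether each entry in the roidb is usable. An entry is valid if it contains at least one foreground box or at least one background box.

When the roidb contains only ground-truth boxes, each box's overlap with its own class is trivially 1, so every box counts as foreground; the entry is therefore valid as long as it contains at least one annotated object.

If the roidb also contains proposals, RoIs whose max overlap falls in [BG_THRESH_LO, BG_THRESH_HI) are treated as background, and those with overlap ≥ FG_THRESH as foreground. Each entry must have at least one foreground or one background RoI; otherwise it is filtered out.

After the unusable entries have been dropped, the function returns filtered_roidb.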
```python
def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs."""

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after)
    return filtered_roidb
```
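To see the rule on concrete numbers, here is a standalone version of is_valid that works on a bare max_overlaps array. The thresholds FG_THRESH = 0.5, BG_THRESH_HI = 0.5 and BG_THRESH_LO = 0.1 are assumed defaults for illustration; check the values in your own config before relying on them:

```python
import numpy as np

# Assumed thresholds (illustrative defaults; your config may override them)
FG_THRESH = 0.5
BG_THRESH_HI = 0.5
BG_THRESH_LO = 0.1


def is_valid(max_overlaps):
    """Same rule as filter_roidb's is_valid, applied to a max_overlaps array."""
    overlaps = np.asarray(max_overlaps, dtype=float)
    fg_inds = np.where(overlaps >= FG_THRESH)[0]
    bg_inds = np.where((overlaps < BG_THRESH_HI) & (overlaps >= BG_THRESH_LO))[0]
    return len(fg_inds) > 0 or len(bg_inds) > 0


print(is_valid([1.0, 1.0]))   # True: ground-truth boxes overlap themselves fully
print(is_valid([0.3, 0.02]))  # True: 0.3 lies in [0.1, 0.5), so it is background
print(is_valid([0.05, 0.0]))  # False: nothing reaches 0.1, the entry is dropped
print(is_valid([]))           # False: no RoIs at all
```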
train_net
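Trains the network given a solver definition (solver_prototxt) and a training roidb. It first removes unusable entries with filter_roidb, then builds a SolverWrapper (optionally initialized from pretrained_model) and runs train_model for up to max_iters iterations, returning the paths of the saved models.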
```python
def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths
```