tf-faster-rcnn代码理解之trianval_net.py

来源:互联网 发布:里欧万塔 知乎 编辑:程序博客网 时间:2024/06/06 15:10

原始工程代码是通过tf-faster-rcnn\experiments\scripts目录下的train_faster_rcnn.sh调用tf-faster-rcnn\tools\trainval_net.py进行模型训练。为了方便使用pycharm对整个训练工程进行调试,故修改trianval_net.py使之不需要shell脚本引导,可以直接运行。修改之后的代码如下:

# --------------------------------------------------------# Tensorflow Faster R-CNN# Licensed under The MIT License [see LICENSE for details]# Written by Zheqi He, Xinlei Chen, based on code from Ross Girshick# --------------------------------------------------------from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport _init_pathsfrom model.train_val import get_training_roidb, train_netfrom model.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_output_tb_dirfrom datasets.factory import get_imdbimport datasets.imdbimport argparseimport pprintimport numpy as npimport sysimport tensorflow as tffrom nets.vgg16 import vgg16from nets.resnet_v1 import resnetv1class args:  """  Parse input arguments  """  cfg_file = '/home/whao/tf-faster-rcnn/experiments/cfgs/vgg16.yml'  weight = '/home/whao/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt'  imdb_name = 'voc_2007_trainval'  imdbval_name = 'voc_2007_test'  max_iters = 100000  tag = None  net = 'vgg16'  set_cfgs = ['ANCHOR_SCALES', '[8,16,32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'TRAIN.STEPSIZE', '50000']def combined_roidb(imdb_names):  """  Combine multiple roidbs  """  def get_roidb(imdb_name):    imdb = get_imdb(imdb_name)    print('Loaded dataset `{:s}` for training'.format(imdb.name))    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)    print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))    roidb = get_training_roidb(imdb)    return roidb  roidbs = [get_roidb(s) for s in imdb_names.split('+')]  roidb = roidbs[0]  if len(roidbs) > 1:    for r in roidbs[1:]:      roidb.extend(r)    tmp = get_imdb(imdb_names.split('+')[1])    imdb = datasets.imdb.imdb(imdb_names, tmp.classes)  else:    imdb = get_imdb(imdb_names)  return imdb, roidbif __name__ == '__main__':#  args = parse_args()  print('Called with args:')  if args.cfg_file is not None:    cfg_from_file(args.cfg_file)  if args.set_cfgs is not None:    cfg_from_list(args.set_cfgs)  print('Using config:')  pprint.pprint(cfg)  np.random.seed(cfg.RNG_SEED)  # train set  imdb, roidb = combined_roidb(args.imdb_name)  print('{:d} roidb entries'.format(len(roidb)))  # output directory where the models are saved  output_dir = get_output_dir(imdb, args.tag)  print('Output will be saved to `{:s}`'.format(output_dir))  # tensorboard directory where the summaries are saved during training  tb_dir = get_output_tb_dir(imdb, args.tag)  print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))  # also add the validation set, but with no flipping images  orgflip = cfg.TRAIN.USE_FLIPPED  cfg.TRAIN.USE_FLIPPED = False  _, valroidb = combined_roidb(args.imdbval_name)  print('{:d} validation roidb entries'.format(len(valroidb)))  cfg.TRAIN.USE_FLIPPED = orgflip  # load network  if args.net == 'vgg16':    net = vgg16(batch_size=cfg.TRAIN.IMS_PER_BATCH)  else:    raise NotImplementedError      train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,            pretrained_model=args.weight,            max_iters=args.max_iters)
以上代码中,定义了args类代替shell传参。首先需要把训练集按照pascal voc的格式处理好,包括文件名与标签个事和Main中的txt的文件。

代码的执行流程是先读取cfg_file所指定的yml文件来配置部分超参量。执行函数为cfg_from_file(args.cfg_file),它把yml中的超参数合并到config.py中定义的__C对象中,它是类EasyDict的对象。

然后,通过cfg_from_list(args.set_cfgs)配置__C对象中的变量。

接下来,开始处理训练集,通过combined_roidb(args.imdb_name)收集训练集,它通过调用lib/datasets/factory.py中的get_imdb()获得数据集,获得类pascal_voc的对象imdb,再设置区域推荐的方式,默认为gt,通过lib/model/train_val.py中的函数get_training_roidb()获得roidb,即每张图片中的区域推荐样本,其为实际为imdb中的一个变量。打印出区域推荐样本的数量

接下来设置训练好的模型和tensorboard文件的存储路径,再获取验证集的数据,前面的训练的数据是经过数据增强的,每张图片都经过旋转,验证集不进行数据增强。

接下来,配置vgg16网络的batch数量,默认是设置为1。

最后调用train_val.py中的train_net()函数开启训练。

未完待续。

原创粉丝点击