faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py

来源:互联网 发布:知乎 家庭交换机 编辑:程序博客网 时间:2024/05/16 09:02
转载自:faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py - 野孩子的专栏 - 博客频道 - CSDN.NET

http://blog.csdn.net/u010668907/article/details/51945320

faster用python版本的https://github.com/rbgirshick/py-faster-rcnn

train_faster_rcnn_alt_opt.py源码在https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/train_faster_rcnn_alt_opt.py

faster rcnn训练的开始是:faster_rcnn_alt_opt.sh。下面命令是训练的,还有它的参数说明。

1.调用最初脚本的说明

cd $FRCN_ROOT

# ./experiments/scripts/faster_rcnn_alt_opt.sh  GPU  NET  DATASET [options args to {train,test}_net.py]

# GPU_ID is the GPU you want to train on

# NET in {ZF, VGG_CNN_M_1024, VGG16} is the network arch to use

# DATASET is only pascal_voc for now

train_faster_rcnn_alt_opt.py的源码:

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #!/usr/bin/env python  
  2.   
  3. # --------------------------------------------------------  
  4. # Faster R-CNN  
  5. # Copyright (c) 2015 Microsoft  
  6. # Licensed under The MIT License [see LICENSE for details]  
  7. # Written by Ross Girshick  
  8. # --------------------------------------------------------  
  9.   
  10. """Train a Faster R-CNN network using alternating optimization. 
  11. This tool implements the alternating optimization algorithm described in our 
  12. NIPS 2015 paper ("Faster R-CNN: Towards Real-time Object Detection with Region 
  13. Proposal Networks." Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun.) 
  14. """  
  15.   
  16. import _init_paths  
  17. from fast_rcnn.train import get_training_roidb, train_net  
  18. from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir  
  19. from datasets.factory import get_imdb  
  20. from rpn.generate import imdb_proposals  
  21. import argparse  
  22. import pprint  
  23. import numpy as np  
  24. import sys, os  
  25. import multiprocessing as mp  
  26. import cPickle  
  27. import shutil  
  28.   
  29. def parse_args():  
  30.     """ 
  31.     Parse input arguments 
  32.     """  
  33.     parser = argparse.ArgumentParser(description='Train a Faster R-CNN network')  
  34.     parser.add_argument('--gpu', dest='gpu_id',  
  35.                         help='GPU device id to use [0]',  
  36.                         default=0, type=int)  
  37.     parser.add_argument('--net_name', dest='net_name',  
  38.                         help='network name (e.g., "ZF")',  
  39.                         default=None, type=str)  
  40.     parser.add_argument('--weights', dest='pretrained_model',  
  41.                         help='initialize with pretrained model weights',  
  42.                         default=None, type=str)  
  43.     parser.add_argument('--cfg', dest='cfg_file',  
  44.                         help='optional config file',  
  45.                         default=None, type=str)  
  46.     parser.add_argument('--imdb', dest='imdb_name',  
  47.                         help='dataset to train on',  
  48.                         default='voc_2007_trainval', type=str)  
  49.     parser.add_argument('--set', dest='set_cfgs',  
  50.                         help='set config keys', default=None,  
  51.                         nargs=argparse.REMAINDER)  
  52.   
  53.     if len(sys.argv) == 1:  
  54.         parser.print_help()  
  55.         sys.exit(1)  
  56.   
  57.     args = parser.parse_args()  
  58.     return args  
  59.   
  60. def get_roidb(imdb_name, rpn_file=None):  
  61.     imdb = get_imdb(imdb_name)  
  62.     print 'Loaded dataset `{:s}` for training'.format(imdb.name)  
  63.     imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)  
  64.     print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)  
  65.     if rpn_file is not None:  
  66.         imdb.config['rpn_file'] = rpn_file  
  67.     roidb = get_training_roidb(imdb)  
  68.     return roidb, imdb  
  69.   
  70. def get_solvers(net_name):  
  71.     # Faster R-CNN Alternating Optimization  
  72.     n = 'faster_rcnn_alt_opt'  
  73.     # Solver for each training stage  
  74.     solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],  
  75.                [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],  
  76.                [net_name, n, 'stage2_rpn_solver60k80k.pt'],  
  77.                [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]  
  78.     solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]  
  79.     # Iterations for each training stage  
  80.     max_iters = [80000400008000040000]  
  81.     # max_iters = [100, 100, 100, 100]  
  82.     # Test prototxt for the RPN  
  83.     rpn_test_prototxt = os.path.join(  
  84.         cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')  
  85.     return solvers, max_iters, rpn_test_prototxt  
  86.   
  87. # ------------------------------------------------------------------------------  
  88. # Pycaffe doesn't reliably free GPU memory when instantiated nets are discarded  
  89. # (e.g. "del net" in Python code). To work around this issue, each training  
  90. # stage is executed in a separate process using multiprocessing.Process.  
  91. # ------------------------------------------------------------------------------  
  92.   
  93. def _init_caffe(cfg):  
  94.     """Initialize pycaffe in a training process. 
  95.     """  
  96.   
  97.     import caffe  
  98.     # fix the random seeds (numpy and caffe) for reproducibility  
  99.     np.random.seed(cfg.RNG_SEED)  
  100.     caffe.set_random_seed(cfg.RNG_SEED)  
  101.     # set up caffe  
  102.     caffe.set_mode_gpu()  
  103.     caffe.set_device(cfg.GPU_ID)  
  104.   
  105. def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,  
  106.               max_iters=None, cfg=None):  
  107.     """Train a Region Proposal Network in a separate training process. 
  108.     """  
  109.   
  110.     # Not using any proposals, just ground-truth boxes  
  111.     cfg.TRAIN.HAS_RPN = True  
  112.     cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression  
  113.     cfg.TRAIN.PROPOSAL_METHOD = 'gt'  
  114.     cfg.TRAIN.IMS_PER_BATCH = 1  
  115.     print 'Init model: {}'.format(init_model)  
  116.     print('Using config:')  
  117.     pprint.pprint(cfg)  
  118.   
  119.     import caffe  
  120.     _init_caffe(cfg)  
  121.   
  122.     roidb, imdb = get_roidb(imdb_name)  
  123.     print 'roidb len: {}'.format(len(roidb))  
  124.     output_dir = get_output_dir(imdb)  
  125.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  126.   
  127.     model_paths = train_net(solver, roidb, output_dir,  
  128.                             pretrained_model=init_model,  
  129.                             max_iters=max_iters)  
  130.     # Cleanup all but the final model  
  131.     for i in model_paths[:-1]:  
  132.         os.remove(i)  
  133.     rpn_model_path = model_paths[-1]  
  134.     # Send final model path through the multiprocessing queue  
  135.     queue.put({'model_path': rpn_model_path})  
  136.   
  137. def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,  
  138.                  rpn_test_prototxt=None):  
  139.     """Use a trained RPN to generate proposals. 
  140.     """  
  141.   
  142.     cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre NMS filtering  
  143.     cfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS  
  144.     print 'RPN model: {}'.format(rpn_model_path)  
  145.     print('Using config:')  
  146.     pprint.pprint(cfg)  
  147.   
  148.     import caffe  
  149.     _init_caffe(cfg)  
  150.   
  151.     # NOTE: the matlab implementation computes proposals on flipped images, too.  
  152.     # We compute them on the image once and then flip the already computed  
  153.     # proposals. This might cause a minor loss in mAP (less proposal jittering).  
  154.     imdb = get_imdb(imdb_name)  
  155.     print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)  
  156.   
  157.     # Load RPN and configure output directory  
  158.     rpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)  
  159.     output_dir = get_output_dir(imdb)  
  160.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  161.     # Generate proposals on the imdb  
  162.     rpn_proposals = imdb_proposals(rpn_net, imdb)  
  163.     # Write proposals to disk and send the proposal file path through the  
  164.     # multiprocessing queue  
  165.     rpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]  
  166.     rpn_proposals_path = os.path.join(  
  167.         output_dir, rpn_net_name + '_proposals.pkl')  
  168.     with open(rpn_proposals_path, 'wb') as f:  
  169.         cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL)  
  170.     print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)  
  171.     queue.put({'proposal_path': rpn_proposals_path})  
  172.   
  173. def train_fast_rcnn(queue=None, imdb_name=None, init_model=None, solver=None,  
  174.                     max_iters=None, cfg=None, rpn_file=None):  
  175.     """Train a Fast R-CNN using proposals generated by an RPN. 
  176.     """  
  177.   
  178.     cfg.TRAIN.HAS_RPN = False           # not generating prosals on-the-fly  
  179.     cfg.TRAIN.PROPOSAL_METHOD = 'rpn'   # use pre-computed RPN proposals instead  
  180.     cfg.TRAIN.IMS_PER_BATCH = 2  
  181.     print 'Init model: {}'.format(init_model)  
  182.     print 'RPN proposals: {}'.format(rpn_file)  
  183.     print('Using config:')  
  184.     pprint.pprint(cfg)  
  185.   
  186.     import caffe  
  187.     _init_caffe(cfg)  
  188.   
  189.     roidb, imdb = get_roidb(imdb_name, rpn_file=rpn_file)  
  190.     output_dir = get_output_dir(imdb)  
  191.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  192.     # Train Fast R-CNN  
  193.     model_paths = train_net(solver, roidb, output_dir,  
  194.                             pretrained_model=init_model,  
  195.                             max_iters=max_iters)  
  196.     # Cleanup all but the final model  
  197.     for i in model_paths[:-1]:  
  198.         os.remove(i)  
  199.     fast_rcnn_model_path = model_paths[-1]  
  200.     # Send Fast R-CNN model path over the multiprocessing queue  
  201.     queue.put({'model_path': fast_rcnn_model_path})  
  202.   
  203. if __name__ == '__main__':  
  204.     args = parse_args()  
  205.   
  206.     print('Called with args:')  
  207.     print(args)  
  208.   
  209.     if args.cfg_file is not None:  
  210.         cfg_from_file(args.cfg_file)  
  211.     if args.set_cfgs is not None:  
  212.         cfg_from_list(args.set_cfgs)  
  213.     cfg.GPU_ID = args.gpu_id  
  214.   
  215.     # --------------------------------------------------------------------------  
  216.     # Pycaffe doesn't reliably free GPU memory when instantiated nets are  
  217.     # discarded (e.g. "del net" in Python code). To work around this issue, each  
  218.     # training stage is executed in a separate process using  
  219.     # multiprocessing.Process.  
  220.     # --------------------------------------------------------------------------  
  221.   
  222.     # queue for communicated results between processes  
  223.     mp_queue = mp.Queue()  
  224.     # solves, iters, etc. for each training stage  
  225.     solvers, max_iters, rpn_test_prototxt = get_solvers(args.net_name)  
  226.   
  227.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  228.     print 'Stage 1 RPN, init from ImageNet model'  
  229.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  230.   
  231.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'  
  232.     mp_kwargs = dict(  
  233.             queue=mp_queue,  
  234.             imdb_name=args.imdb_name,  
  235.             init_model=args.pretrained_model,  
  236.             solver=solvers[0],  
  237.             max_iters=max_iters[0],  
  238.             cfg=cfg)  
  239.     p = mp.Process(target=train_rpn, kwargs=mp_kwargs)  
  240.     p.start()  
  241.     rpn_stage1_out = mp_queue.get()  
  242.     p.join()  
  243.   
  244.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  245.     print 'Stage 1 RPN, generate proposals'  
  246.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  247.   
  248.     mp_kwargs = dict(  
  249.             queue=mp_queue,  
  250.             imdb_name=args.imdb_name,  
  251.             rpn_model_path=str(rpn_stage1_out['model_path']),  
  252.             cfg=cfg,  
  253.             rpn_test_prototxt=rpn_test_prototxt)  
  254.     p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)  
  255.     p.start()  
  256.     rpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']  
  257.     p.join()  
  258.   
  259.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  260.     print 'Stage 1 Fast R-CNN using RPN proposals, init from ImageNet model'  
  261.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  262.   
  263.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'  
  264.     mp_kwargs = dict(  
  265.             queue=mp_queue,  
  266.             imdb_name=args.imdb_name,  
  267.             init_model=args.pretrained_model,  
  268.             solver=solvers[1],  
  269.             max_iters=max_iters[1],  
  270.             cfg=cfg,  
  271.             rpn_file=rpn_stage1_out['proposal_path'])  
  272.     p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)  
  273.     p.start()  
  274.     fast_rcnn_stage1_out = mp_queue.get()  
  275.     p.join()  
  276.   
  277.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  278.     print 'Stage 2 RPN, init from stage 1 Fast R-CNN model'  
  279.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  280.   
  281.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'  
  282.     mp_kwargs = dict(  
  283.             queue=mp_queue,  
  284.             imdb_name=args.imdb_name,  
  285.             init_model=str(fast_rcnn_stage1_out['model_path']),  
  286.             solver=solvers[2],  
  287.             max_iters=max_iters[2],  
  288.             cfg=cfg)  
  289.     p = mp.Process(target=train_rpn, kwargs=mp_kwargs)  
  290.     p.start()  
  291.     rpn_stage2_out = mp_queue.get()  
  292.     p.join()  
  293.   
  294.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  295.     print 'Stage 2 RPN, generate proposals'  
  296.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  297.   
  298.     mp_kwargs = dict(  
  299.             queue=mp_queue,  
  300.             imdb_name=args.imdb_name,  
  301.             rpn_model_path=str(rpn_stage2_out['model_path']),  
  302.             cfg=cfg,  
  303.             rpn_test_prototxt=rpn_test_prototxt)  
  304.     p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)  
  305.     p.start()  
  306.     rpn_stage2_out['proposal_path'] = mp_queue.get()['proposal_path']  
  307.     p.join()  
  308.   
  309.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  310.     print 'Stage 2 Fast R-CNN, init from stage 2 RPN R-CNN model'  
  311.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  312.   
  313.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'  
  314.     mp_kwargs = dict(  
  315.             queue=mp_queue,  
  316.             imdb_name=args.imdb_name,  
  317.             init_model=str(rpn_stage2_out['model_path']),  
  318.             solver=solvers[3],  
  319.             max_iters=max_iters[3],  
  320.             cfg=cfg,  
  321.             rpn_file=rpn_stage2_out['proposal_path'])  
  322.     p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)  
  323.     p.start()  
  324.     fast_rcnn_stage2_out = mp_queue.get()  
  325.     p.join()  
  326.   
  327.     # Create final model (just a copy of the last stage)  
  328.     final_path = os.path.join(  
  329.             os.path.dirname(fast_rcnn_stage2_out['model_path']),  
  330.             args.net_name + '_faster_rcnn_final.caffemodel')  
  331.     print 'cp {} -> {}'.format(  
  332.             fast_rcnn_stage2_out['model_path'], final_path)  
  333.     shutil.copy(fast_rcnn_stage2_out['model_path'], final_path)  
  334.     print 'Final model: {}'.format(final_path)  

2. train_faster_rcnn_alt_opt.py的部分参数说明

net_name:      {ZF, VGG_CNN_M_1024, VGG16}

pretrained_model:      data/imagenet_models/${net_name}.v2.caffemodel

cfg_file:     experiments/cfgs/faster_rcnn_alt_opt.yml

imdb_name:     "voc_2007_trainval" or "voc_2007_test"

 

cfg.TRAIN.HAS_RPN = True表示用xml提供的propoal

cfg是配置文件,它的默认值放在上面的cfg_file里,其他还可以自己写配置文件之后与默认配置文件融合。

  2.1 net_name是用get_solvers()找到网络。还要用到cfg的参数MODELS_DIR

    例子是joinMODELS_DIR, net_name, 'faster_rcnn_alt_opt', 'stage1_rpn_solver60k80k.pt'

  2.2 imdb_namefactory中被拆成‘2007’(year)和‘trainval/test’(split)到类pascal_voc中产生相应的imdb

  2.3 整个step的大致流程:

(ImageNet model)->stage1_rpn_train->rpn_test

                                                                           |(proposal_path)

                    (ImageNetmodel)->stage1_fast_rcnn_train-> stage2_rpn_train-> rpn_test-> stage2_fast_rcnn_train

  2.4 数据imdbroidb

  roidb原本是imdb的一个属性,但imdb其实是为了计算roidb存在的,他所有的其他属性和方法都是为了计算roidb


0 0
原创粉丝点击