Faster R-CNN Source Code Reading Notes (TensorFlow + Python)

Source: Internet. Editor: 程序博客网. Time: 2024/05/16 09:31


Source code: https://github.com/CharlesShang/TFFRCNN

1) Running a pre-trained model directly

The demo performs detection using a VGG16 network trained for detection on PASCAL VOC 2007.

python demo.py --model /models/VGGnet_fast_rcnn_iter_150000.ckpt (the path after --model is where I saved the downloaded model)

The test images used by the demo are stored under /data/demo.
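A detection demo of this kind typically post-processes the raw per-class detections with non-maximum suppression (NMS) before drawing boxes. The following is a minimal pure-Python sketch of greedy NMS, not TFFRCNN's own implementation. It assumes boxes are given as `[x1, y1, x2, y2, score]` with inclusive pixel bounds (hence the `+ 1` in widths and heights); the helper names `iou` and `greedy_nms` and the 0.3 IoU threshold are illustrative assumptions.

```python
from typing import List, Sequence


def iou(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two [x1, y1, x2, y2, ...] boxes with inclusive pixel bounds."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    iw = max(0.0, ix2 - ix1 + 1)
    ih = max(0.0, iy2 - iy1 + 1)
    inter = iw * ih
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)


def greedy_nms(dets: List[Sequence[float]], iou_thresh: float = 0.3) -> List[int]:
    """Hypothetical helper: return indices of kept boxes, highest score first."""
    order = sorted(range(len(dets)), key=lambda i: dets[i][4], reverse=True)
    keep: List[int] = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # drop every remaining box that overlaps the kept box too much
        order = [i for i in order if iou(dets[best], dets[i]) <= iou_thresh]
    return keep


if __name__ == "__main__":
    detections = [
        [10, 10, 50, 50, 0.9],
        [12, 12, 52, 52, 0.8],      # near-duplicate of the first box
        [100, 100, 140, 140, 0.7],  # separate object
    ]
    print(greedy_nms(detections))  # [0, 2]
```

The greedy loop always keeps the highest-scoring remaining box, so duplicates of the same object collapse into one detection while disjoint boxes survive.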


2) Training your own model

### Training on Pascal VOC 2007

1. Download the training, validation, test data and VOCdevkit

    ```Shell
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
    ```

2. Extract all of these tars into one directory named `VOCdevkit`

    ```Shell
    tar xvf VOCtrainval_06-Nov-2007.tar
    tar xvf VOCtest_06-Nov-2007.tar
    tar xvf VOCdevkit_08-Jun-2007.tar
    ```

3. It should have this basic structure

    ```Shell
    $VOCdevkit/                           # development kit
    $VOCdevkit/VOCcode/                   # VOC utility code
    $VOCdevkit/VOC2007                    # image sets, annotations, etc.
    # ... and several other directories ...
    ```

4. Create symlinks for the PASCAL VOC dataset

    ```Shell
    cd $TFFRCNN/data
    ln -s $VOCdevkit VOCdevkit2007
    ```

5. Download the pre-trained model [VGG16](https://drive.google.com/open?id=0ByuDEGFYmWsbNVF5eExySUtMZmM) and put it in the path `./data/pretrain_model/VGG_imagenet.npy`

6. Run the training script

    ```Shell
    cd $TFFRCNN
    python ./faster_rcnn/train_net.py --gpu 0 --weights ./data/pretrain_model/VGG_imagenet.npy --imdb voc_2007_trainval --iters 70000 --cfg ./experiments/cfgs/faster_rcnn_end2end.yml --network VGGnet_train --set EXP_DIR exp_dir
    ```

7. Run profiling

    ```Shell
    cd $TFFRCNN
    # install a visualization tool
    sudo apt-get install graphviz
    ./experiments/profiling/run_profiling.sh
    # generates the image ./experiments/profiling/profile.png
    ```
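Steps 3 and 4 can also be checked from Python before launching a long training run. This is a minimal sketch, assuming the standard PASCAL VOC 2007 layout (`VOC2007/Annotations`, `VOC2007/ImageSets/Main`, `VOC2007/JPEGImages`); the helpers `missing_subdirs` and `link_voc_devkit` are hypothetical and not part of TFFRCNN.

```python
import os
from typing import List

# Sub-directories a VOC 2007 devkit is expected to contain
# (assumption: the standard PASCAL VOC layout).
REQUIRED_SUBDIRS = [
    os.path.join("VOC2007", "Annotations"),
    os.path.join("VOC2007", "ImageSets", "Main"),
    os.path.join("VOC2007", "JPEGImages"),
]


def missing_subdirs(devkit: str) -> List[str]:
    """Return the required sub-directories that do not exist under devkit."""
    return [d for d in REQUIRED_SUBDIRS
            if not os.path.isdir(os.path.join(devkit, d))]


def link_voc_devkit(devkit: str, data_dir: str) -> str:
    """Hypothetical helper, equivalent to `cd $TFFRCNN/data; ln -s $VOCdevkit VOCdevkit2007`."""
    missing = missing_subdirs(devkit)
    if missing:
        raise FileNotFoundError("VOCdevkit is incomplete, missing: " + ", ".join(missing))
    link = os.path.join(data_dir, "VOCdevkit2007")
    if not os.path.lexists(link):  # do not overwrite an existing link or directory
        os.symlink(os.path.abspath(devkit), link)
    return link
```

For example, `link_voc_devkit("/path/to/VOCdevkit", "/path/to/TFFRCNN/data")` fails early with a readable message if the tars were extracted into the wrong place.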

This training code targets the VOC2007 dataset; to train on your own dataset, convert it to the VOC format.
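Converting a dataset to the VOC format mainly means writing one XML annotation per image into `Annotations/` and listing the image ids (file names without extension) in `ImageSets/Main/*.txt`. Below is a minimal sketch using only the standard library's `xml.etree.ElementTree`. It writes the commonly used VOC fields (`filename`, `size`, and per-object `name`, `difficult`, `bndbox`); the helper `voc_annotation` and the sample values are hypothetical, so check which fields your dataset loader actually reads.

```python
import xml.etree.ElementTree as ET
from typing import List, Tuple

# (class_name, xmin, ymin, xmax, ymax) in 1-based pixel coordinates, as in VOC
Box = Tuple[str, int, int, int, int]


def voc_annotation(filename: str, width: int, height: int,
                   boxes: List[Box], depth: int = 3) -> ET.ElementTree:
    """Build a PASCAL VOC-style annotation tree for one image (hypothetical helper)."""
    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "VOC2007"
    ET.SubElement(root, "filename").text = filename
    size = ET.SubElement(root, "size")
    for tag, value in (("width", width), ("height", height), ("depth", depth)):
        ET.SubElement(size, tag).text = str(value)
    for name, xmin, ymin, xmax, ymax in boxes:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = name
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", xmin), ("ymin", ymin),
                           ("xmax", xmax), ("ymax", ymax)):
            ET.SubElement(bndbox, tag).text = str(value)
    return ET.ElementTree(root)


if __name__ == "__main__":
    tree = voc_annotation("img_0001.jpg", 640, 480, [("person", 50, 60, 200, 400)])
    # would be saved with tree.write("Annotations/img_0001.xml"),
    # and "img_0001" added as a line of ImageSets/Main/trainval.txt
    print(ET.tostring(tree.getroot(), encoding="unicode"))
```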

3) Source code reading notes

(1) train_net.py in TFFRCNN

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import argparse  # standard Python module for parsing command-line arguments and options
import pprint    # pretty-prints Python data structures
import numpy as np
import pdb       # debugger, handy when the script is written in a plain text editor
import sys
import os.path

this_dir = os.path.dirname(__file__)
sys.path.insert(0, this_dir + '/..')  # make the repository root importable (for `lib.*`)
# for p in sys.path: print p
# print (this_dir)

from lib.fast_rcnn.train import get_training_roidb, train_net
from lib.fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_log_dir
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--restore', dest='restore',
                        help='restore or not',
                        default=1, type=int)

    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    # the YAML file from --cfg is applied first, then --set overrides individual keys
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the numpy random seed for reproducibility
        np.random.seed(cfg.RNG_SEED)

    imdb = get_imdb(args.imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    # build the roidb used for training (defined in train.py): it appends horizontally
    # flipped images (when flipping is enabled in cfg) and attaches descriptive
    # attributes to each entry of the original roidb
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    log_dir = get_log_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))
    print('Logs will be saved to `{:s}`'.format(log_dir))

    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print(device_name)

    network = get_network(args.network_name)
    print('Use network `{:s}` in training'.format(args.network_name))

    train_net(network, imdb, roidb,
              output_dir=output_dir,
              log_dir=log_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters,
              restore=bool(int(args.restore)))
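The `--set` option is declared with `nargs=argparse.REMAINDER`, so every token after it is collected into `args.set_cfgs` as a flat key/value list that is then passed to `cfg_from_list`. The sketch below reproduces only a subset of `parse_args()` (a simplification so it runs without the repository) and parses a trimmed version of the training command from step 6. Because REMAINDER swallows the rest of the command line, `--set` should be the last option.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Subset of the options defined in train_net.py's parse_args()."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', default=0, type=int)
    parser.add_argument('--iters', dest='max_iters', default=70000, type=int)
    parser.add_argument('--imdb', dest='imdb_name', default='kitti_train', type=str)
    parser.add_argument('--network', dest='network_name', default=None, type=str)
    parser.add_argument('--restore', dest='restore', default=1, type=int)
    # REMAINDER: all tokens after --set end up in set_cfgs, so put --set last
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER)
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    return build_parser().parse_args(argv)


if __name__ == '__main__':
    args = parse(['--gpu', '0', '--imdb', 'voc_2007_trainval', '--iters', '70000',
                  '--network', 'VGGnet_train', '--set', 'EXP_DIR', 'exp_dir'])
    print(args.set_cfgs)            # ['EXP_DIR', 'exp_dir']
    print(bool(int(args.restore)))  # True, because --restore defaults to 1
```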

(1) The get_imdb and get_roidb functions:
http://blog.csdn.net/sloanqin/article/details/51537713
http://www.cnblogs.com/alanma/p/6802835.html
http://www.cnblogs.com/alanma/p/6803713.html
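`get_imdb` behaves as a registry mapping a dataset name such as `voc_2007_trainval` to a constructor, and `get_training_roidb` (when flipping is enabled) appends a horizontally mirrored copy of every image's boxes before attaching extra attributes. The sketch below mirrors that structure with a hypothetical stand-in class `ToyImdb` instead of the real dataset class. The flip arithmetic `x1' = W - x2 - 1`, `x2' = W - x1 - 1` assumes 0-based pixel coordinates, the convention of the py-faster-rcnn code this project derives from.

```python
from typing import Callable, Dict, List


class ToyImdb:
    """Hypothetical stand-in for an imdb: per-image widths and [x1, y1, x2, y2] boxes."""

    def __init__(self, name: str, widths: List[int], boxes: List[List[List[int]]]):
        self.name = name
        self.widths = widths
        self.roidb = [{"boxes": b, "flipped": False} for b in boxes]

    def append_flipped_images(self) -> None:
        """Append a horizontally mirrored copy of every roidb entry."""
        for width, entry in zip(self.widths, list(self.roidb)):
            flipped = [[width - x2 - 1, y1, width - x1 - 1, y2]
                       for x1, y1, x2, y2 in entry["boxes"]]
            self.roidb.append({"boxes": flipped, "flipped": True})
        self.widths = self.widths * 2  # flipped images keep their original width


# Registry: dataset name -> zero-argument constructor (one toy image per set)
_SETS: Dict[str, Callable[[], ToyImdb]] = {}
for _split in ["train", "val", "trainval", "test"]:
    _name = "voc_2007_{}".format(_split)
    # bind _name via a default argument so each lambda keeps its own value
    _SETS[_name] = lambda name=_name: ToyImdb(name, [100], [[[10, 20, 39, 60]]])


def get_imdb(name: str) -> ToyImdb:
    if name not in _SETS:
        raise KeyError("Unknown dataset: {}".format(name))
    return _SETS[name]()


def get_training_roidb(imdb: ToyImdb, use_flipped: bool = True) -> List[dict]:
    """Sketch: optionally add flipped images, then attach descriptive attributes."""
    if use_flipped:
        imdb.append_flipped_images()
    for i, entry in enumerate(imdb.roidb):
        # the real code adds more fields, such as the image path and height
        entry["width"] = imdb.widths[i]
    return imdb.roidb


if __name__ == "__main__":
    roidb = get_training_roidb(get_imdb("voc_2007_trainval"))
    print(len(roidb), roidb[1]["boxes"])  # 2 [[60, 20, 89, 60]]
```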

(2) The cfg module (config.py) explained:
http://www.cnblogs.com/alanma/p/6800944.html
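`cfg_from_list` consumes the flat list collected by `--set` (for example `['EXP_DIR', 'exp_dir']`), walks dotted keys such as `TRAIN.SNAPSHOT_ITERS` into the nested config, and converts string values with `ast.literal_eval`, falling back to the raw string. Below is a minimal sketch of that behavior on a plain nested dict. The sample keys and default values are hypothetical, and the real function in config.py operates on an `EasyDict` and may differ in detail.

```python
from ast import literal_eval
from typing import Any, Dict, List, Optional

# Illustrative stand-in for the global cfg (keys and values are hypothetical).
cfg: Dict[str, Any] = {
    "EXP_DIR": "default",
    "RNG_SEED": 3,
    "TRAIN": {"SNAPSHOT_ITERS": 5000, "USE_FLIPPED": True},
}


def cfg_from_list(cfg_list: List[str], config: Optional[Dict[str, Any]] = None) -> None:
    """Set config keys from a flat [key1, value1, key2, value2, ...] list."""
    if config is None:
        config = cfg
    if len(cfg_list) % 2 != 0:
        raise ValueError("expected an even number of key/value tokens")
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        *parents, leaf = key.split(".")
        node = config
        for sub in parents:           # e.g. "TRAIN" for "TRAIN.SNAPSHOT_ITERS"
            node = node[sub]          # unknown parent keys raise KeyError
        if leaf not in node:
            raise KeyError(key)
        try:
            value = literal_eval(raw)  # "10000" -> 10000, "False" -> False
        except (ValueError, SyntaxError):
            value = raw                # plain strings such as exp_dir stay strings
        if type(value) is not type(node[leaf]):
            raise TypeError("type mismatch for {}: {!r}".format(key, raw))
        node[leaf] = value


if __name__ == "__main__":
    cfg_from_list(["EXP_DIR", "exp_dir", "TRAIN.SNAPSHOT_ITERS", "10000"])
    print(cfg["EXP_DIR"], cfg["TRAIN"]["SNAPSHOT_ITERS"])  # exp_dir 10000
```

The type check makes a mistyped override (for example a string where an integer is expected) fail immediately instead of surfacing deep inside training.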