Faster R-CNN Source Code Reading Notes (TF + Python Version)
Source: Internet. Time: 2024/05/16 09:31
Source code: https://github.com/CharlesShang/TFFRCNN
1) Run a pre-trained model directly
The demo performs detection using a VGG16 network trained for detection on PASCAL VOC 2007.
python demo.py --model /models/VGGnet_fast_rcnn_iter_150000.ckpt (the path after --model is where I saved the downloaded model)
The test images are stored under /data/demo.
2) Train your own model
### Training on Pascal VOC 2007

1. Download the training, validation, test data and VOCdevkit

    ```Shell
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
    ```

2. Extract all of these tars into one directory named `VOCdevkit`

    ```Shell
    tar xvf VOCtrainval_06-Nov-2007.tar
    tar xvf VOCtest_06-Nov-2007.tar
    tar xvf VOCdevkit_08-Jun-2007.tar
    ```

3. It should have this basic structure

    ```Shell
    $VOCdevkit/                           # development kit
    $VOCdevkit/VOCcode/                   # VOC utility code
    $VOCdevkit/VOC2007                    # image sets, annotations, etc.
    # ... and several other directories ...
    ```

4. Create symlinks for the PASCAL VOC dataset

    ```Shell
    cd $TFFRCNN/data
    ln -s $VOCdevkit VOCdevkit2007
    ```

5. Download pre-trained model [VGG16](https://drive.google.com/open?id=0ByuDEGFYmWsbNVF5eExySUtMZmM) and put it in the path `./data/pretrain_model/VGG_imagenet.npy`

6. Run training scripts

    ```Shell
    cd $TFFRCNN
    python ./faster_rcnn/train_net.py --gpu 0 --weights ./data/pretrain_model/VGG_imagenet.npy --imdb voc_2007_trainval --iters 70000 --cfg ./experiments/cfgs/faster_rcnn_end2end.yml --network VGGnet_train --set EXP_DIR exp_dir
    ```

7. Run a profiling

    ```Shell
    cd $TFFRCNN
    # install a visualization tool
    sudo apt-get install graphviz
    ./experiments/profiling/run_profiling.sh
    # generate an image ./experiments/profiling/profile.png
    ```
This training code mainly targets the VOC2007 dataset. To train on your own dataset, you can convert it to the VOC format.
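As a rough illustration of what converting to the VOC format can involve, the sketch below writes one VOC-style XML annotation using only the standard library. The function name `make_voc_annotation` and the sample values are hypothetical, and the set of fields is an assumption based on what VOC2007 annotation files commonly contain. Which fields the TFFRCNN loader actually reads is not confirmed here, so check the dataset code before relying on it.

```python
import xml.etree.ElementTree as ET


def make_voc_annotation(filename, width, height, objects, depth=3):
    """Build a PASCAL VOC style <annotation> element (hypothetical helper).

    `objects` is a list of (class_name, xmin, ymin, xmax, ymax) tuples
    given in pixel coordinates.
    """
    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = "VOC2007"
    ET.SubElement(root, "filename").text = filename

    # Image size block: width, height and number of channels.
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = str(depth)

    # One <object> element per labelled bounding box.
    for name, xmin, ymin, xmax, ymax in objects:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = name
        ET.SubElement(obj, "difficult").text = "0"
        box = ET.SubElement(obj, "bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"),
                              (xmin, ymin, xmax, ymax)):
            ET.SubElement(box, tag).text = str(value)
    return root


# Made-up sample values, for illustration only.
annotation = make_voc_annotation("000001.jpg", 500, 375,
                                 [("dog", 48, 240, 195, 371)])
xml_text = ET.tostring(annotation, encoding="unicode")
print(xml_text)
```

In a real conversion, each image would get one such file under the `Annotations` directory, and the image IDs would be listed in the image-set files. That layout is again an assumption based on the standard VOC devkit.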
3) Source code reading notes
(1) train_net.py under TFFRCNN
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import argparse  # standard Python module for parsing command-line arguments and options
import pprint    # pretty-prints Python data structures
import numpy as np
import pdb       # debugger; lets you debug scripts written in a plain text editor
import sys
import os.path

this_dir = os.path.dirname(__file__)
sys.path.insert(0, this_dir + '/..')
# for p in sys.path: print p
# print (this_dir)

from lib.fast_rcnn.train import get_training_roidb, train_net
from lib.fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_log_dir
from lib.datasets.factory import get_imdb
from lib.networks.factory import get_network
from lib.fast_rcnn.config import cfg


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--restore', dest='restore',
                        help='restore or not',
                        default=1, type=int)

    if len(sys.argv) == 1:
        parser.print_help()
        # sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)

    imdb = get_imdb(args.imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)

    # Get the roidb used for training. get_training_roidb is defined in train.py;
    # it applies horizontal flipping and adds some descriptive attributes to the
    # original roidb.
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    log_dir = get_log_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)
    print 'Logs will be saved to `{:s}`'.format(log_dir)

    device_name = '/gpu:{:d}'.format(args.gpu_id)
    print device_name

    network = get_network(args.network_name)
    print 'Use network `{:s}` in training'.format(args.network_name)

    train_net(network, imdb, roidb,
              output_dir=output_dir,
              log_dir=log_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters,
              restore=bool(int(args.restore)))
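The `--set` option in the script above is declared with `nargs=argparse.REMAINDER`, which is easy to misread. The following minimal sketch reuses only two of the script's arguments to show how that option collects everything after it into a list. It is a standalone illustration and does not import anything from TFFRCNN.

```python
import argparse

# A reduced parser with just two of the options from train_net.py.
parser = argparse.ArgumentParser(description="Sketch of the --set handling")
parser.add_argument("--iters", dest="max_iters", default=70000, type=int)
parser.add_argument("--set", dest="set_cfgs", default=None,
                    nargs=argparse.REMAINDER)

# Same shape as the training command: --set comes last.
args = parser.parse_args(["--iters", "100", "--set", "EXP_DIR", "exp_dir"])
print(args.max_iters)  # 100
print(args.set_cfgs)   # ['EXP_DIR', 'exp_dir']

# Anything placed after --set is swallowed by it, including other options,
# so --iters here is treated as a plain string and the default is kept.
late = parser.parse_args(["--set", "EXP_DIR", "exp_dir", "--iters", "5"])
print(late.max_iters)
print(late.set_cfgs)
```

This is presumably why the training command puts `--set EXP_DIR exp_dir` at the very end: any option written after `--set` would end up inside `set_cfgs` instead of being parsed.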
(1) The get_imdb and get_roidb functions:
http://blog.csdn.net/sloanqin/article/details/51537713
http://www.cnblogs.com/alanma/p/6802835.html
http://www.cnblogs.com/alanma/p/6803713.html
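The comment in train_net.py says that the training roidb is augmented with horizontally flipped samples. The sketch below shows how box coordinates could be mirrored for such a flip. The helper name `flip_boxes_horizontally` is hypothetical, and the convention (0-based, inclusive `[x1, y1, x2, y2]` pixel coordinates) is an assumption; the repository's own implementation may differ in detail.

```python
def flip_boxes_horizontally(boxes, image_width):
    """Mirror [x1, y1, x2, y2] boxes left-right (hypothetical helper).

    Assumes 0-based, inclusive pixel coordinates, so column x maps to
    image_width - x - 1. The old right edge becomes the new left edge.
    """
    return [
        [image_width - x2 - 1, y1, image_width - x1 - 1, y2]
        for x1, y1, x2, y2 in boxes
    ]


boxes = [[10, 20, 50, 80]]
flipped = flip_boxes_horizontally(boxes, image_width=100)
print(flipped)  # [[49, 20, 89, 80]]

# Flipping twice should give back the original boxes.
restored = flip_boxes_horizontally(flipped, image_width=100)
print(restored == boxes)  # True
```

Note that x1 and x2 swap roles under the flip, which keeps x1 <= x2 in the mirrored box. The y coordinates are unchanged.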
(2) Notes on the cfg (config.py) module:
http://www.cnblogs.com/alanma/p/6800944.html
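train_net.py passes the `--set` values to `cfg_from_list`. As a rough sketch of what such a function has to do, the code below merges a flat `[key, value, key, value, ...]` list into a nested dictionary. This is not the repository's implementation: `merge_overrides`, the dotted-key convention and the `TRAIN.BATCH_SIZE` key are assumptions made for illustration. Only `EXP_DIR` and `RNG_SEED` appear in the text above.

```python
from ast import literal_eval


def merge_overrides(config, overrides):
    """Apply a flat [key1, value1, key2, value2, ...] list to a nested dict.

    Hypothetical helper: dotted keys such as "TRAIN.BATCH_SIZE" address
    nested entries, and only keys that already exist may be overridden.
    """
    if len(overrides) % 2 != 0:
        raise ValueError("overrides must be key/value pairs")
    for key, raw in zip(overrides[0::2], overrides[1::2]):
        parts = key.split(".")
        node = config
        for name in parts[:-1]:
            node = node[name]  # walk down to the parent dictionary
        leaf = parts[-1]
        if leaf not in node:
            raise KeyError(key)
        try:
            value = literal_eval(raw)  # "256" -> 256, "True" -> True
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as directory names
        node[leaf] = value
    return config


config = {"EXP_DIR": "default", "RNG_SEED": 3, "TRAIN": {"BATCH_SIZE": 128}}
merge_overrides(config, ["EXP_DIR", "exp_dir", "TRAIN.BATCH_SIZE", "256"])
print(config)
```

Rejecting unknown keys is a deliberate choice in this sketch: a typo in a `--set` key then fails loudly instead of silently adding an unused setting.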