Fast RCNN训练阶段代码解析
来源:互联网 发布:软件学院为什么分数低 编辑:程序博客网 时间:2024/06/18 08:48
- 首先是入口文件trian_net.py,真正处理数据的文件都在lib文件里,包括数据集制作的文件在lib/datasets下,网络训练测试的文件在lib/fast_rcnn下,lib/roi_data_layer是用python实现的网络的输入层。
parse_args函数解析输入参数:网络参数定义,初始化模型(这两项没有默认值必须自己指定),显卡号,最大迭代次数,训练数据位置等。def parse_args(): """ Parse input arguments """ parser = argparse.ArgumentParser(description='Train a Fast R-CNN network') parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int) parser.add_argument('--solver', dest='solver', help='solver prototxt', default=None, type=str) parser.add_argument('--iters', dest='max_iters', help='number of iterations to train', default=40000, type=int) parser.add_argument('--weights', dest='pretrained_model', help='initialize with pretrained model weights', default=None, type=str) parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str) parser.add_argument('--imdb', dest='imdb_name', help='dataset to train on', default='voc_2007_trainval', type=str) parser.add_argument('--rand', dest='randomize', help='randomize (do not use a fixed seed)', action='store_true') parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER) if len(sys.argv) == 1: parser.print_help() sys.exit(1) args = parser.parse_args() return args程序入口,可以看做main函数if __name__ == '__main__': args = parse_args()#解析输入参数,存入args print('Called with args:') print(args) if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) print('Using config:') pprint.pprint(cfg)#设置caffe if not args.randomize: # fix the random seeds (numpy and caffe) for reproducibility np.random.seed(cfg.RNG_SEED) caffe.set_random_seed(cfg.RNG_SEED) # set up caffe caffe.set_mode_gpu() if args.gpu_id is not None: caffe.set_device(args.gpu_id)#读取训练数据,包括训练图片的位置,物体gt(外接框)坐标,selective search方法产生的proposal(候选框)。调用的是lib/datasets/factory.py中的get_imdb函数 imdb = get_imdb(args.imdb_name) print 'Loaded dataset `{:s}` for training'.format(imdb.name) #上步得到的数据imdb进一步制作成训练时的数据,主要是把图片翻转,扩充训练样本. roidb = get_training_roidb(imdb) output_dir = get_output_dir(imdb, None) print 'Output will be saved to `{:s}`'.format(output_dir)#真正的训练函数,调用lib/fast_rcnn/train.py中的train_net函数 train_net(args.solver, roidb, output_dir, pretrained_model=args.pretrained_model, max_iters=args.max_iters)
2.训练数据读取主函数是imdb=get_imdb(args.imdb_name)函数,在下面前辈的博客中已经很清楚了,请参考。
http://www.cnblogs.com/louyihang-loves-baiyan/archive/2015/10/16/4885659.html
将训练数据读取到imdb变量中只是简单的将数据读入,并没有将样本标注为正负类,数据扩充等操作。紧接着调用 roidb = get_training_roidb(imdb)制作训练数据集,实现在lib/fast_rcnn/train.py中
def get_training_roidb(imdb): """Returns a roidb (Region of Interest database) for use in training.""" if cfg.TRAIN.USE_FLIPPED: print 'Appending horizontally-flipped training examples...' imdb.append_flipped_images()#数据翻转操作,扩充训练数据集 print 'done' print 'Preparing training data...' rdl_roidb.prepare_roidb(imdb)#prepare_roidb函数中max_classes是每个proposal重合度最大的物体gt的类别,max_overlaps是最大重合度。 print 'done' return imdb.roidb
3.lib/roi_data_layer下的网输入层
caffe提供了python,也就是说可以用python实现某一个层,fast_rcnn就用python实现了网络的输入层。
fast_rcnn中实现caffe支持python主要设置了两个文件:fast-rcnn/caffe-fast-rcnn/src/caffe.proto和fast-rcnn/caffe-fast-rcnn/include/caffe/python_layer.hpp文件
. caffe.proto中注册Python层参数:
optional PythonParameter python_param = 130;........// Message that stores parameters used by PythonLayermessage PythonParameter { optional string module = 1; optional string layer = 2; // This value is set to the attribute `param_str_` of your custom // `PythonLayer` object in Python before calling `setup()` method. This could // be a number, a string, a dictionary in Python dict format or JSON etc. You // may parse this string in `setup` method and use them in `forward` and // `backward`. optional string param_str = 3 [default = ''];}
. python_layer.hpp头文件定义了需要python实现的几个输入层的重要函数:setup函数,reshape函数,forward函数,backward函数。
#ifndef CAFFE_PYTHON_LAYER_HPP_#define CAFFE_PYTHON_LAYER_HPP_#include <boost/python.hpp>#include <string>#include <vector>#include "caffe/layer.hpp"namespace bp = boost::python;namespace caffe {#define PYTHON_LAYER_ERROR() { \ PyObject *petype, *pevalue, *petrace; \ PyErr_Fetch(&petype, &pevalue, &petrace); \ bp::object etype(bp::handle<>(bp::borrowed(petype))); \ bp::object evalue(bp::handle<>(bp::borrowed(bp::allow_null(pevalue)))); \ bp::object etrace(bp::handle<>(bp::borrowed(bp::allow_null(petrace)))); \ bp::object sio(bp::import("StringIO").attr("StringIO")()); \ bp::import("traceback").attr("print_exception")( \ etype, evalue, etrace, bp::object(), sio); \ LOG(INFO) << bp::extract<string>(sio.attr("getvalue")())(); \ PyErr_Restore(petype, pevalue, petrace); \ throw; \}template <typename Dtype>class PythonLayer : public Layer<Dtype> { public: PythonLayer(PyObject* self, const LayerParameter& param) : Layer<Dtype>(param), self_(bp::handle<>(bp::borrowed(self))) { } virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { try { self_.attr("param_str_") = bp::str( this->layer_param_.python_param().param_str()); #LayerSetUp函数调用setup函数,具体实现在lib/roi_data_layer/layer.py中的setup函数。下面reshape,forward,backward函数同理 self_.attr("setup")(bottom, top); } catch (bp::error_already_set) { PYTHON_LAYER_ERROR(); } } virtual void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { try { self_.attr("reshape")(bottom, top); } catch (bp::error_already_set) { PYTHON_LAYER_ERROR(); } } virtual inline const char* type() const { return "Python"; } protected: virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { try { #forward函数的实现在lib/roi_data_layer/layer.py中 self_.attr("forward")(bottom, top); } catch (bp::error_already_set) { PYTHON_LAYER_ERROR(); } } virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { try { #backward函数的实现在lib/roi_data_layer/layer.py中 self_.attr("backward")(top, propagate_down, bottom); } catch (bp::error_already_set) { PYTHON_LAYER_ERROR(); } } private: bp::object self_;};} // namespace caffe#endif
2 0
- Fast-RCNN解析:训练阶段代码导读
- Fast-RCNN解析:训练阶段代码导读
- Fast-RCNN解析:训练阶段代码导读
- Fast RCNN训练阶段代码解析
- Fast-RCNN解析:训练阶段代码导读
- fast rcnn 代码解析(一)
- fast-rcnn训练实战
- Fast rcnn训练
- fast RCNN训练车型模块
- 研究Fast rcnn代码
- fast-rcnn训练自己数据集以及demo代码解读和总结(面向fast-rcnn初学者)
- fast-rcnn训练自己数据集以及demo代码解读和总结(面向fast-rcnn初学者)
- Fast-RCNN代码解读(0)
- Fast rcnn cpu 训练自己的数据
- Fast-rcnn 训练(1)-安装
- Fast-rcnn 训练(2)- 跑demo
- fast-rcnn训练自己的数据
- Fast RCNN训练自己的数据集
- python学习之unicode编码
- 数据结构复习——线性表的链式存储实现(双向链表)
- studio项目是完全可以转换成eclipse的
- 2006 - MySQL server has gone away
- MainActivity
- Fast RCNN训练阶段代码解析
- 发邮件自动回复本机IP——python版本
- 《Thinkinginjava》第11章-持有对象
- 黑名单来电自动挂断
- 欢迎使用CSDN-markdown编辑器
- Android之Activity的几种跳转方式
- 有一个嵌入式软件开发专家的博客值得关注
- 算法导论_第十章_基本数据结构
- 反编译工具