Fast RCNN训练阶段代码解析

来源:互联网 发布:软件学院为什么分数低 编辑:程序博客网 时间:2024/06/18 08:48
  1. 首先是入口文件trian_net.py,真正处理数据的文件都在lib文件里,包括数据集制作的文件在lib/datasets下,网络训练测试的文件在lib/fast_rcnn下,lib/roi_data_layer是用python实现的网络的输入层。
parse_args函数解析输入参数:网络参数定义,初始化模型(这两项没有默认值必须自己指定),显卡号,最大迭代次数,训练数据位置等。def parse_args():    """    Parse input arguments    """    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')    parser.add_argument('--gpu', dest='gpu_id',                        help='GPU device id to use [0]',                        default=0, type=int)    parser.add_argument('--solver', dest='solver',                        help='solver prototxt',                        default=None, type=str)    parser.add_argument('--iters', dest='max_iters',                        help='number of iterations to train',                        default=40000, type=int)    parser.add_argument('--weights', dest='pretrained_model',                        help='initialize with pretrained model weights',                        default=None, type=str)    parser.add_argument('--cfg', dest='cfg_file',                        help='optional config file',                        default=None, type=str)    parser.add_argument('--imdb', dest='imdb_name',                        help='dataset to train on',                        default='voc_2007_trainval', type=str)    parser.add_argument('--rand', dest='randomize',                        help='randomize (do not use a fixed seed)',                        action='store_true')    parser.add_argument('--set', dest='set_cfgs',                        help='set config keys', default=None,                        nargs=argparse.REMAINDER)    if len(sys.argv) == 1:        parser.print_help()        sys.exit(1)    args = parser.parse_args()    return args程序入口,可以看做main函数if __name__ == '__main__':    args = parse_args()#解析输入参数,存入args    print('Called with args:')    print(args)    if args.cfg_file is not None:        cfg_from_file(args.cfg_file)    if args.set_cfgs is not None:        cfg_from_list(args.set_cfgs)    print('Using config:')    pprint.pprint(cfg)#设置caffe    if not args.randomize:        # fix the random seeds (numpy and caffe) for reproducibility        np.random.seed(cfg.RNG_SEED)        caffe.set_random_seed(cfg.RNG_SEED)    # set up caffe    caffe.set_mode_gpu()    if args.gpu_id is not None:        caffe.set_device(args.gpu_id)#读取训练数据,包括训练图片的位置,物体gt(外接框)坐标,selective search方法产生的proposal(候选框)。调用的是lib/datasets/factory.py中的get_imdb函数    imdb = get_imdb(args.imdb_name)    print 'Loaded dataset `{:s}` for training'.format(imdb.name)    #上步得到的数据imdb进一步制作成训练时的数据,主要是把图片翻转,扩充训练样本.    roidb = get_training_roidb(imdb)    output_dir = get_output_dir(imdb, None)    print 'Output will be saved to `{:s}`'.format(output_dir)#真正的训练函数,调用lib/fast_rcnn/train.py中的train_net函数    train_net(args.solver, roidb, output_dir,              pretrained_model=args.pretrained_model,              max_iters=args.max_iters)

2.训练数据读取主函数是imdb=get_imdb(args.imdb_name)函数,在下面前辈的博客中已经很清楚了,请参考。
http://www.cnblogs.com/louyihang-loves-baiyan/archive/2015/10/16/4885659.html
将训练数据读取到imdb变量中只是简单的将数据读入,并没有将样本标注为正负类,数据扩充等操作。紧接着调用 roidb = get_training_roidb(imdb)制作训练数据集,实现在lib/fast_rcnn/train.py中

def get_training_roidb(imdb):    """Returns a roidb (Region of Interest database) for use in training."""    if cfg.TRAIN.USE_FLIPPED:        print 'Appending horizontally-flipped training examples...'        imdb.append_flipped_images()#数据翻转操作,扩充训练数据集        print 'done'    print 'Preparing training data...'    rdl_roidb.prepare_roidb(imdb)#prepare_roidb函数中max_classes是每个proposal重合度最大的物体gt的类别,max_overlaps是最大重合度。    print 'done'    return imdb.roidb

3.lib/roi_data_layer下的网输入层
caffe提供了python,也就是说可以用python实现某一个层,fast_rcnn就用python实现了网络的输入层。
fast_rcnn中实现caffe支持python主要设置了两个文件:fast-rcnn/caffe-fast-rcnn/src/caffe.proto和fast-rcnn/caffe-fast-rcnn/include/caffe/python_layer.hpp文件
. caffe.proto中注册Python层参数:

  optional PythonParameter python_param = 130;........// Message that stores parameters used by PythonLayermessage PythonParameter {  optional string module = 1;  optional string layer = 2;  // This value is set to the attribute `param_str_` of your custom  // `PythonLayer` object in Python before calling `setup()` method. This could  // be a number, a string, a dictionary in Python dict format or JSON etc. You  // may parse this string in `setup` method and use them in `forward` and  // `backward`.  optional string param_str = 3 [default = ''];}

. python_layer.hpp头文件定义了需要python实现的几个输入层的重要函数:setup函数,reshape函数,forward函数,backward函数。

#ifndef CAFFE_PYTHON_LAYER_HPP_#define CAFFE_PYTHON_LAYER_HPP_#include <boost/python.hpp>#include <string>#include <vector>#include "caffe/layer.hpp"namespace bp = boost::python;namespace caffe {#define PYTHON_LAYER_ERROR() { \  PyObject *petype, *pevalue, *petrace; \  PyErr_Fetch(&petype, &pevalue, &petrace); \  bp::object etype(bp::handle<>(bp::borrowed(petype))); \  bp::object evalue(bp::handle<>(bp::borrowed(bp::allow_null(pevalue)))); \  bp::object etrace(bp::handle<>(bp::borrowed(bp::allow_null(petrace)))); \  bp::object sio(bp::import("StringIO").attr("StringIO")()); \  bp::import("traceback").attr("print_exception")( \    etype, evalue, etrace, bp::object(), sio); \  LOG(INFO) << bp::extract<string>(sio.attr("getvalue")())(); \  PyErr_Restore(petype, pevalue, petrace); \  throw; \}template <typename Dtype>class PythonLayer : public Layer<Dtype> { public:  PythonLayer(PyObject* self, const LayerParameter& param)      : Layer<Dtype>(param), self_(bp::handle<>(bp::borrowed(self))) { }  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,      const vector<Blob<Dtype>*>& top) {    try {      self_.attr("param_str_") = bp::str(        this->layer_param_.python_param().param_str());        #LayerSetUp函数调用setup函数,具体实现在lib/roi_data_layer/layer.py中的setup函数。下面reshape,forward,backward函数同理      self_.attr("setup")(bottom, top);    } catch (bp::error_already_set) {      PYTHON_LAYER_ERROR();    }  }  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,      const vector<Blob<Dtype>*>& top) {    try {      self_.attr("reshape")(bottom, top);    } catch (bp::error_already_set) {      PYTHON_LAYER_ERROR();    }  }  virtual inline const char* type() const { return "Python"; } protected:  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,      const vector<Blob<Dtype>*>& top) {    try {    #forward函数的实现在lib/roi_data_layer/layer.py中      self_.attr("forward")(bottom, top);    } catch (bp::error_already_set) {      PYTHON_LAYER_ERROR();    }  }  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {    try {    #backward函数的实现在lib/roi_data_layer/layer.py中      self_.attr("backward")(top, propagate_down, bottom);    } catch (bp::error_already_set) {      PYTHON_LAYER_ERROR();    }  } private:  bp::object self_;};}  // namespace caffe#endif
2 0
原创粉丝点击