Faster R-CNN —— ROIDB和Minibatch函数实现分析

来源:互联网 发布:cnc数控编程代码大全 编辑:程序博客网 时间:2024/06/03 23:43
几点说明介绍:1.全文分成两部分,关于ROI database的函数分析和关于minibatch的函数分析2.文字描述函数调用关系有些乱,因此使用标题等级来体现函数之间的层次关系,标题等级高的函数调用/含有标题等级低的函数3.函数内容有的不太重要的部分做了简化或者省略,为了突出层次关系,函数内部的对其他函数内部的调用有些只是使用了`function(...)`的形式4.本文参考的代码基于tensorflow,代码链接为https://github.com/endernewton/tf-faster-rcnn

ROIDB

最顶层文件:trainval_net.py

def parse_args():                               #解析参数def combined_roidb(imdb_names):if __name__ == '__main__':

(1)combined_roidb()

def combined_roidb(imdb_names):  def get_roidb(imdb_name):                                      #内部函数    imdb = get_imdb(imdb_name)                                  (1)/lib/tools/factory.py    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)         (2)/lib/dataset/imdb.py    roidb = get_training_roidb(imdb)                            (3)/lib/model/train_val.py    return roidb  roidbs = [get_roidb(s) for s in imdb_names.split('+')]        #the function above  roidb = roidbs[0]  if len(roidbs) > 1:    for r in roidbs[1:]:      roidb.extend(r)    tmp = get_imdb(imdb_names.split('+')[1])                        imdb = datasets.imdb.imdb(imdb_names, tmp.classes)            else:    imdb = get_imdb(imdb_names)  return imdb, roidb

(1.1)get_imdb()

​ 通过给数据集的名字返回该数据集对应的类的对象

(1.2)set_proposal_method()

​ 由于cfg.TRAIN.PROPOSAL_METHOD = ‘tf’,内部相当于执行了一次gt_roidb函数:

def gt_roidb():                     #位于/lib/tools/pascal_voc.py    ...    gt_roidb = [self._load_pascal_annotation(index)                         for index in self.image_index]    #关键语句,调用同一文件下面的_load_pascal_annotation函数    #该函数从XML文件中加载图片和bbox    ...    #最后返回一个字典,包含“boxes”,“gt_classes”等

(1.3)get_training_roidb()

def get_training_roidb(imdb):  """Returns a roidb (Region of Interest database) for use in training."""  if cfg.TRAIN.USE_FLIPPED:    imdb.append_flipped_images()                    #通过翻转增加样本数量,位于/lib/dataset/imdb.py  rdl_roidb.prepare_roidb(imdb)                    (1)/lib/roi_data_layer/roidb.py  return imdb.roidb

(1.3.1)prepare_roidb()

def prepare_roidb(imdb):                            "为roidb加了一些说明性的属性"    for i in range(len(imdb.image_index)):        roidb[i]['image'/'width'/'height'/'max_classes'/'max_overlaps'...]

output:

roidb[img_index]包含的key value boxes box位置信息,box_num*4的np array gt_overlaps 所有box在不同类别的得分,box_num*class_num矩阵 gt_classes 所有box的真实类别,box_num长度的list flipped 是否翻转 image 该图片的路径,字符串 width 图片的宽 height 图片的高 max_overlaps 每个box的在所有类别的得分最大值,box_num长度 max_classes 每个box的得分最高所对应的类,box_num长度 bbox_targets 每个box的类别,以及与最接近的gt-box的4个方位偏移

(2)if name == ‘main‘:

if name == 'main':    args = parse_args()                                         #the function above    ...    combined_roidb(...)                                         #the function above    ...    #build network    net = vgg16/resnetv1(num_layers=50/101/152)/mobilenetv1     (2.1)/lib/nets/network.py    train_net(...)                                              (2.2)/lib/model/train_val.py

​ 主函数首先调用之前定义的parse_args()、combined_roidb(),每个数据集返回了带有该数据集信息的imdb和每张影像的roidb

接下来(2.1)步构建网络

(2.1)RPN和ROI-Pooling网络的构建

​ 以vgg16为例

class vgg16(Network):                                   #继承与Network的子类    ...    def _build_network():                               #核心函数        ...        net = slim.repeat/max_pool2d                    #构建CNN网络        ...        # build the anchors for the image        self._anchor_component()        # 构建RPN        rois = self._region_proposal(net, is_training, initializer)        # 构建ROI-Pooling        if cfg.POOLING_MODE == 'crop':          pool5 = self._crop_pool_layer(net, rois, "pool5")        else:          raise NotImplementedError    ...

(2.2)train_net()

def train_net(...):                                 #训练网络    filter_roidb(...)                                       #对roi进行筛选,去掉没有用的,筛选标准为:    # Valid images have:    #   (1) At least one foreground RoI OR    #   (2) At least one background RoI    ...    sw = SolverWrapper(...)                         #构造一个SolverWrapper类的对象,用于训练    sw.train_model(sess, max_iters)                 #

train_model函数和batch

def train_model(self, sess, max_iters):    # 为训练和验证构造RoIDataLayer的对象    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)       #/lib/roi_data_layer/layer.py    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)    ...    while iter < max_iters + 1:      # 训练数据的时候用随机梯度下降,一次获取一个batch的数据,获取方法调用forward函数      blobs = self.data_layer.forward()                                           #forward函数位于 /lib/roi_data_layer/layer.py ,内部仅仅调用_get_next_minibatch函数

最高级:layer.py

def _shuffle_roidb_inds(self):              #洗牌函数,打乱database顺序def _get_next_minibatch_inds(self):         #如果if条件满足,用shuffle函数打乱顺序并选出新一组batch的index并返回    if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):      self._shuffle_roidb_inds()def _get_next_minibatch(self):    _get_next_minibatch_inds(...)           #得到新一组batch的index    get_minibatch(...)                      #调用中间级函数,根据上面得到的index读出图像,def forward(self):    blobs = self._get_next_minibatch()      #顶层函数

中间级:minibatch.py

def get_minibatch(roidb, num_classes):    “根据提供的roidb,调用_get_image_blob读取图像数据,并随机选择构造出一个minibatch样本,被layer.py里的函数调用”    im_blob, im_scales = _get_image_blob(roidb, random_scale_inds)    return blobsdef _get_image_blob(roidb, scale_inds):    for i in range(num_images):        im = cv2.imread(roidb[i]['image'])        prep_im_for_blob(...)               #调用最底层文件里的函数    ...    blob = im_list_to_blob(processed_ims)   #调用最底层文件里的函数

最底层:/lib/utils/blob.py

这个文件里的函数主要用于将图像构造成方便训练的blob类型数据结构def im_list_to_blob(ims):                                               #输入一个ims,将其转化为4维array形式的blobdef prep_im_for_blob(im, pixel_means, target_size, max_size):           #求取图像的缩放比例,然后将图像resize