faster rcnn 中pascal_voc.py

来源:互联网 发布:手机日记本软件 编辑:程序博客网 时间:2024/05/01 02:52

该部分代码功能在于实现了一个pascol _voc的类,该类继承自imdb,用于负责数据交互部分。

初始化函数

在初始化自身的同时,先调用了父类的初始化方法,将imdb _name传入,例如(‘voc_2007_trainval’),下面是成员变量的初始化:

{    year:’2007’    image _set:’trainval’    devkit_path:’data/VOCdevkit2007’    data _path:’data/VOCdevkit2007/VOC2007’    classes:(…)_如果想要训练自己的数据,需要修改这里_    class _to _ind:{…} _一个将类名转换成下标的字典 _    image _ext:’.jpg’    image _index:[‘000001’,’000003’,……]_根据trainval.txt获取到的image索引_    roidb _handler:<Method gt_roidb >    salt:  <Object uuid >    comp _id:’comp4’    config:{…}}


class pascal _voc(imdb):  def __init__(self,image_set, year, devkit_path=None):      imdb.__init__(self,'voc_' + year + '_' + image_set)      self._year = year      self._image_set =image_set      self._devkit_path =self._get_default_path() if devkit_path is None                          else devkit_path      self._data_path =os.path.join(self._devkit_path, 'VOC' + self._year)      self._classes = ('__background__',# always index 0                      'aeroplane', 'bicycle', 'bird', 'boat',                      'bottle', 'bus', 'car', 'cat', 'chair',                      'cow', 'diningtable', 'dog', 'horse',                      'motorbike', 'person', 'pottedplant',                      'sheep', 'sofa', 'train', 'tvmonitor')      self._class_to_ind =dict(zip(self.classes, xrange(self.num_classes)))      self._image_ext ='.jpg'      self._image_index =self._load_image_set_index()      # Default to roidb handler      self._roidb_handler =self.selective_search_roidb      self._salt =str(uuid.uuid4())      self._comp_id ='comp4'       # PASCAL specificconfig options      self.config ={'cleanup'     : True,                     'use_salt'    : True,                    'use_diff'    : False,                    'matlab_eval' : False,                    'rpn_file'    : None,                    'min_size'    : 2}       assertos.path.exists(self._devkit_path),               'VOCdevkit path does not exist:{}'.format(self._devkit_path)      assertos.path.exists(self._data_path),               'Path does notexist: {}'.format(self._data_path)


image_path _from _index

以下两个函数非常容易理解,就是根据图片的索引,比如‘000001’获取在JPEGImages下对应的图片路径

def image_path_at(self, i):        """        Return the absolutepath to image i in the image sequence.        """        returnself.image_path_from_index(self._image_index[i])     defimage_path_from_index(self, index):        """        Construct an imagepath from the image's "index" identifier.        """        image_path =os.path.join(self._data_path, 'JPEGImages',                                 index + self._image_ext)        assertos.path.exists(image_path), \                'Path does not exist: {}'.format(image_path)        return image_path# load _image _set _index# 该函数根据/VOCdevkit2007/VOC2007/ImageSets/Main/<image _set >.txt加载图像的索引    def_load_image_set_index(self):        """        Load the indexeslisted in this dataset's image set file.        """        # Example path toimage set file:        # self._devkit_path+ /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt        image_set_file =os.path.join(self._data_path, 'ImageSets', 'Main',                                      self._image_set + '.txt')        assertos.path.exists(image_set_file), \                'Path doesnot exist: {}'.format(image_set_file)        withopen(image_set_file) as f:            image_index =[x.strip() for x in f.readlines()]        return image_index

 

_get_default_path

返回默认的数据源路径,这里是放在data下的VOCDevkit2007,如果有自己的数据集,修改该函数即可

 def_get_default_path(self):        """        Return the defaultpath where PASCAL VOC is expected to be installed.        """        returnos.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)


gt_roidb

这个函数是该对象的核心函数之一,它将返回roidb数据对象。首先它会在cache路径下找到以扩展名’.pkl’结尾的缓存,这个文件是通过cPickle工具将roidb序列化存储的。如果该文件存在,那么它会先读取这里的内容,以提高效率(所以如果你换数据集的时候,要先把cache文件给删除,否则会造成错误)。接着,它将调用 _load _pascal _annotation这个私有函数加载roidb中的数据,并将其保存在缓存文件中,返回roidb。roidb的格式可以参考下文 _load_pascal _annotation的注释

def gt_roidb(self):        """        Return the databaseof ground-truth regions of interest.         This functionloads/saves from/to a cache file to speed up future calls.        """        cache_file =os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')        ifos.path.exists(cache_file):            withopen(cache_file, 'rb') as fid:                roidb =cPickle.load(fid)            print '{} gt roidbloaded from {}'.format(self.name, cache_file)            return roidb         gt_roidb =[self._load_pascal_annotation(index)                    forindex in self.image_index]        withopen(cache_file, 'wb') as fid:            cPickle.dump(gt_roidb,fid, cPickle.HIGHEST_PROTOCOL)        print 'wrote gtroidb to {}'.format(cache_file)        return gt_roidb


selective_search _roidb

这个函数在fasterrcnn中似乎不怎么用到,它也将返回roidb数据对象。首先它同样会在cache路径下找到以扩展名’.pkl’结尾的缓存,如果该文件存在,那么它会先读取这里的内容,以提高效率(如果你换数据集的时候,要先把cache文件给删除,否则会造成错误)。接着,它将调用同时调用gt _roidb()和 _load _selective_search _roidb()获取到两组roidb,再通过merge_roidbs将其合并,最后写入缓存并返回。

def selective_search_roidb(self):        """        Return the databaseof selective search regions of interest.        Ground-truth ROIs are also included.         This functionloads/saves from/to a cache file to speed up future calls.        """        cache_file =os.path.join(self.cache_path,                                 self.name + '_selective_search_roidb.pkl')         ifos.path.exists(cache_file):            withopen(cache_file, 'rb') as fid:                roidb =cPickle.load(fid)            print '{} ssroidb loaded from {}'.format(self.name, cache_file)            return roidb         if int(self._year)== 2007 or self._image_set != 'test':            gt_roidb =self.gt_roidb()            ss_roidb =self._load_selective_search_roidb(gt_roidb)            roidb =imdb.merge_roidbs(gt_roidb, ss_roidb)        else:            roidb = self._load_selective_search_roidb(None)        withopen(cache_file, 'wb') as fid:           cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)        print 'wrote ssroidb to {}'.format(cache_file)        return roidb


_load_selective_search _roidb

selective _search的方法,fasterrcnn一般不使用,暂时可以忽略

def _load_selective_search_roidb(self, gt_roidb):        filename =os.path.abspath(os.path.join(cfg.DATA_DIR,                                               'selective_search_data',                                                self.name + '.mat'))        assertos.path.exists(filename), \               'Selectivesearch data not found at: {}'.format(filename)        raw_data =sio.loadmat(filename)['boxes'].ravel()         box_list = []        for i inxrange(raw_data.shape[0]):            boxes =raw_data[i][:, (1, 0, 3, 2)] - 1            keep =ds_utils.unique_boxes(boxes)            boxes =boxes[keep, :]            keep =ds_utils.filter_small_boxes(boxes, self.config['min_size'])            boxes =boxes[keep, :]           box_list.append(boxes)        return self.create_roidb_from_box_list(box_list, gt_roidb)


_load_pascal _annotation

该函数根据每个图像的索引,到Annotations这个文件夹下去找相应的xml标注数据,然后加载所有的bounding box对象,并去除所有的“复杂”对象。

xml的解析到此结束,接下来是roidb中的几个类成员的赋值:

-  boxes 一个二维数组   每一行存储xminymin xmax ymax

-  gt _classes存储了每个box所对应的类索引(类数组在初始化函数中声明)

-  overlap是一个二维数组,共有num _classes(即类的个数)行,每一行对应的box的类索引处值为1,其余皆为0,后来被转成了稀疏矩阵

-  seg _areas存储着某个box的面积

-  flipped 为false代表该图片还未被翻转(后来在train.py里会将翻转的图片加进去,用该变量用于区分)

最后将这些成员变量组装成roidb返回

def _load_pascal_annotation(self, index):        """        Load image andbounding boxes info from XML file in the PASCAL VOC        format.        """        filename =os.path.join(self._data_path, 'Annotations', index + '.xml')        tree =ET.parse(filename)        objs =tree.findall('object')        if notself.config['use_diff']:            # Exclude thesamples labeled as difficult            non_diff_objs =[                obj for objin objs if int(obj.find('difficult').text) == 0]            # iflen(non_diff_objs) != len(objs):            #     print 'Removed {} difficultobjects'.format(            #         len(objs) - len(non_diff_objs))            objs = non_diff_objs        num_objs = len(objs)         boxes =np.zeros((num_objs, 4), dtype=np.uint16)        gt_classes =np.zeros((num_objs), dtype=np.int32)        overlaps =np.zeros((num_objs, self.num_classes), dtype=np.float32)        # "Seg"area for pascal is just the box area        seg_areas =np.zeros((num_objs), dtype=np.float32)         # Load objectbounding boxes into a data frame.        for ix, obj inenumerate(objs):            bbox =obj.find('bndbox')            # Make pixelindexes 0-based            x1 =float(bbox.find('xmin').text) - 1            y1 =float(bbox.find('ymin').text) - 1            x2 =float(bbox.find('xmax').text) - 1            y2 =float(bbox.find('ymax').text) - 1            cls =self._class_to_ind[obj.find('name').text.lower().strip()]            boxes[ix, :] =[x1, y1, x2, y2]            gt_classes[ix] =cls            # 从anatation直接载入图像的信息,因为本身就是ground-truth , 所以overlap直接设为1            overlaps[ix,cls] = 1.0            seg_areas[ix] =(x2 - x1 + 1) * (y2 - y1 + 1)        # overlaps为 num_objs * K 的数组, K表示总共的类别数, num_objs表示当前这张图片中box的个数        overlaps =scipy.sparse.csr_matrix(overlaps)         return {'boxes' :boxes,               'gt_classes': gt_classes,                'gt_overlaps' : overlaps,                'flipped' :False,                'seg_areas': seg_areas}

test

以下一些函数是测试结果所用,阅读价值不大,理解其功能即可

  def_write_voc_results_file(self, all_boxes):  def _do_python_eval(self,output_dir = 'output'):  def evaluate_detections(self,all_boxes, output_dir):


rpn_roidb

在经过RPN网络产生了proposal以后,这个函数作用是将这些proposal 的 roi与groudtruth结合起来,送入网络训练。

那怎么个结合法呢?proposal 的roidb格式与上面提到的gt_roidb一模一样,只不过overlap由1变成了与最接近的class的重合度。

如何判断是最接近的class呢?每个proposal的box都与groud-truth的box做一次重合度计算,与anchor_target _layer.py中类似

overlap = (重合部分面积) / (proposal _box面积 +gt_boxes面积 - 重合部分面积)

对于每个proposal,选出最大的那个gt_boxes的值,然后填到相应的class index下。

举个例子:

classes: backgroud  cat  fish dog  car  bedproposal1    0     0.65  0     0    0   0proposal2    0       0   0    0.8   0    0

原来对应的class下的1 变成了overlap值罢了。最后用merge_roidbs将gr_roidb与rpn _roidb合并,输出

 def rpn_roidb(self):        if int(self._year)== 2007 or self._image_set != 'test':            gt_roidb =self.gt_roidb()            # 求取rpn_roidb需要以gt_roidb作为参数才能得到            rpn_roidb =self._load_rpn_roidb(gt_roidb)            roidb =imdb.merge_roidbs(gt_roidb, rpn_roidb)        else:            roidb =self._load_rpn_roidb(None)        return roidb     def_load_rpn_roidb(self, gt_roidb):#调用父类方法create_roidb_from_box_list从box_list 中读取每张图像的boxes        filename =self.config['rpn_file']        print 'loading{}'.format(filename)        assertos.path.exists(filename), \               'rpn data notfound at: {}'.format(filename)        with open(filename,'rb') as f:            # 读取rpn_file里的box,形成box_list;box_list为一个列表,每张图像对应其中的一个元素,            # 所以box_list 的大小要与gt_roidb 相同            box_list =cPickle.load(f)        return self.create_roidb_from_box_list(box_list, gt_roidb)


测试所用

if __name__ == '__main__':  from datasets.pascal_vocimport pascal_voc  d = pascal_voc('trainval','2007')  res = d.roidb  from IPython import embed;embed()


0 0
原创粉丝点击