用自己的数据训练Faster-RCNN,tensorflow版本(一)

来源:互联网 发布:java九九乘法表上三角 编辑:程序博客网 时间:2024/06/06 04:54

我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF

参考
http://www.cnblogs.com/CarryPotMan/p/5390336.html

1、按照Faster-RCNN_TF中的步骤,先在本地完成编译。

1.1、环境配置

按照该项目中的README.md ,将需要的几个依赖cython, python-opencv, easydict都安装好,并确保本地计算机中有tensorflow,没有的话自行安装;

1.2、克隆工程:在本地计算机的终端输入

git clone --recursive https://github.com/smallcorgi/Faster-RCNN_TF.git

下载下来的内容都在目录 Faster-RCNN_TF 下;

1.3、编译Cython模块

cd $FRCN_ROOT/lib # 首先进入目录Faster-RCNN_TF/lib
make #编译

编译成功之后,目录Faster-RCNN_TF/lib/nms 和 Faster-RCNN_TF/lib/roi_pooling_layer/ 和Faster-RCNN_TF/lib/utils下面会出现一些.so文件。

注意:如果在这时候,你将该工程原封不动的连带着.so文件一起移植到了另一台电脑上,想重新运行程序的时候,记住,要先删除这几个.so文件,并重新进行编译。因为编译生成的文件是只适应本台计算机的,换一台计算机之后,用原来的.so文件,就行不通了,程序会出错。并且,必须要先删除旧的.so文件,否则就会调用旧的.so文件,而不生成新的.so文件。

2、介绍一下pascal_voc数据集的数据读写接口

工程Faster-RCNN_TF中读取数据的接口都在目录Faster-RCNN_TF/lib/datasets下。

原工程提供了5种数据来训练网络,并分别给出了各自的数据读写接口。
5种数据分别是pascal_voc,coco,kitti,nissan,nthu,各自的数据读写接口分别是Faster-RCNN_TF/lib/datasets 中的pascal_voc.py,coco.py,kitti.py,nissan.py,nthu.py。

我们可以看到Faster-RCNN_TF/lib/datasets目录下还有一些.py文件,分别是:
factory.py:是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用
imdb.py:是数据库读写类的基类,分装了许多db的操作,具体的一些文件读写需要继承继续读写

我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。

2.1、首先说明一下pascal_voc数据集的格式

以VOC2007为例,数据都放在一个叫做VOCdevkit的目录中,我们来看一下目录VOCdevkit的结构:

VOCdevkit/VOCdevkit/VOC2007/VOCdevkit/VOC2007/Annotations #所有图片的XML文件,一张图片对应一个XML文件,XML文件中给出的图片gt的形式是左上角和右下角的坐标VOCdevkit/VOC2007/ImageSets/          VOCdevkit/VOC2007/ImageSets/Layout #里面有三个txt文件,分别是train.txt,trainval.txt,val.txt,存储的分别是训练图片的名字列表,训练验证集的图片名字列表,验证集图片的名字列表(名字均没有.jpg后缀)VOCdevkit/VOC2007/ImageSets/MainVOCdevkit/VOC2007/ImageSets/SegmentationVOCdevkit/VOC2007/JPEGImages  #所有的图片VOCdevkit/VOC2007/SegmentationClass  #segmentations by classVOCdevkit/VOC2007/SegmentationObject  #segmentations by object

Faster-RCNN_TF工程主要用到的是目录Annotations中的XML文件、目录JPEGImages中的图片、目录ImageSets/Layout中的txt文件。

2.2、然后解释一下pascal_voc.py中每个的函数的作用
主函数 if name == ‘main在文件pascal_voc.py的最下面

if __name__ == '__main__':    from datasets.pascal_voc import pascal_voc    d = pascal_voc('trainval', '2007') #pascal_voc是一个类    res = d.roidb    from IPython import embed; embed()

类 pascal_voc中的函数:
class pascal_voc(imdb):
def init(self, image_set, year, devkit_path=None)在文件pascal_voc.py的最上面
是初始化函数,对应着的是pascal_voc的数据集访问格式
(我会按照这个初始化函数里面用到的子函数的顺序来介绍每个子函数的作用,这样看比较直观。在这个初始化函数init中用到的每个子函数我都会有一个标号,方便介绍。)

'''是初始化函数,对应着的是pascal_voc的数据集访问格式:param image_set: 是一个str,值是'train'或者'test'或者'trainval'或者'val',表示的意思是用(训练集)或者(测试集)或者(训练验证集)或者(验证集)里面的数据;:param year: 是一个str,是VOC数据的年份,值是'2007'或者'2012':param devkit_path: pascal_voc数据集所在的路径''''''以下的image_set都以train为例year都以2007为例'''def __init__(self, image_set, year, devkit_path=None):     imdb.__init__(self, 'voc_' + year + '_' + image_set) # 继承了类imdb的初始化函数__init__(),传进去的参数是voc_2007_train。类imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py里面被定义    self._year = year #年份,比如2007    self._image_set = image_set # train     self._devkit_path = self._get_default_path() if devkit_path is None else devkit_path # 这个路径是pascal_voc数据集所在的路径。如果devkit_path is None,返回pascal_voc的默认路径:目录VOCdevkit;如果devkit_path有值,则返回devkit_path。默认路径用函数_get_default_path()获得,标号(1)    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)#就是VOCdevkit/VOC2007    self._classes = ('__background__', # always index 0                     'aeroplane', 'bicycle', 'bird', 'boat',                     'bottle', 'bus', 'car', 'cat', 'chair',                     'cow', 'diningtable', 'dog', 'horse',                     'motorbike', 'person', 'pottedplant',                     'sheep', 'sofa', 'train', 'tvmonitor') #数据集中所包含的全部的object类别    self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) # 构建字典{'__background__':'0','aeroplane':'1', 'bicycle':'2', 'bird':'3', 'boat':'4','bottle':'5', 'bus':'6', 'car':'7', 'cat':'8', 'chair':'9','cow':'10', 'diningtable':'11', 'dog':'12', 'horse':'13','motorbike':'14', 'person':'15', 'pottedplant':'16','sheep':'17', 'sofa':'18', 'train':'19', 'tvmonitor':'20'}  self.num_classes是object的类别总数21(背景也算一类),这个函数继承自Faster-RCNN_TF_R2/lib/datasets/imdb.py    self._image_ext = '.jpg' # 图片后缀名    self._image_index = self._load_image_set_index() #加载了样本的list文件,标号(2)    # Default to roidb handler    #self._roidb_handler = self.selective_search_roidb #当没有RPN的时候,读取并返回候选框ROI的db。函数selective_search_roidb是fast-rcnn提取候选框的方式(fast-rcnn没有RPN),下面会具体讲    self._roidb_handler = self.gt_roidb # 当有RPN的时候,读取并返回图片gt的db。函数gt_roidb里面并没有提取图片的ROI,因为faster-rcnn有RPN,用RPN来提取ROI。函数gt_roidb返回的是图片的gt。标号(3)    self._salt = str(uuid.uuid4())    self._comp_id = 'comp4'    # PASCAL specific config options    self.config = {'cleanup'     : True,                   'use_salt'    : True,                   'use_diff'    : False,                   'matlab_eval' : False,                   'rpn_file'    : None,                   'min_size'    : 2}    assert os.path.exists(self._devkit_path), \        'VOCdevkit path does not exist: {}'.format(self._devkit_path) #如果路径self._devkit_path(也就是目录VOCdevkit)不存在,退出    assert os.path.exists(self._data_path), \        'Path does not exist: {}'.format(self._data_path)  #如果路径self._data_path(也就是VOCdevkit/VOC2007)不存在,退出

标号(1)def _get_default_path(self)

def _get_default_path(self):    """    Return the default path where PASCAL VOC is expected to be installed.    返回数据集pascal_voc的默认路径:Faster-RCNN_TF/data/VOCdevkit/2007    """    return os.path.join(cfg.DATA_DIR, 'VOCdevkit') # cfg.DATA_DIR是在Faster-RCNN_TF/lib/fast_rcnn/config.py里面定义的,

Faster-RCNN_TF/lib/fast_rcnn/config.py中定义DATA_DIR的地方是这样的(在220-224行):

# Root directory of project__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) # 这个路径指的就是目录Faster-RCNN_TF# Data directory__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data')) # 这个路径是Faster-RCNN_TF/data

标号(2)def _load_image_set_index(self)

def _load_image_set_index(self):    """    Load the indexes listed in this dataset's image set file.    得到一个list,这个list里面是集合self._image_set中所有图片的名字(注意,图片名字没有后缀.jpg)    """    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Layout',                                  self._image_set + '.txt')     # image_set_file就是Faster-RCNN_TF/data/VOCdevkit/VOC2007/ImageSets/Layout/train.txt    #之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg)    assert os.path.exists(image_set_file), \            'Path does not exist: {}'.format(image_set_file)    with open(image_set_file) as f: # 读上面的train.txt文件        image_index = [x.strip() for x in f.readlines()] #将train.txt的内容(图片名字)读取出来放在image_index里面    return image_index #得到image_set里面所有图片的名字(没有后缀.jpg)

标号(3)def gt_roidb(self)

def gt_roidb(self):    """    Return the database of ground-truth regions of interest.    读取并返回图片gt的db。这个函数就是将图片的gt加载进来。    其中,pascal_voc图片的gt信息在XML文件中(这个XML文件是pascal_voc数据集本身提供的)    并且,图片的gt被提前放在了一个.pkl文件里面。(这个.pkl文件需要我们自己生成,代码就在该函数中)    This function loads/saves from/to a cache file to speed up future calls.    之所以会将图片的gt提前放在一个.pkl文件里面,是为了不用每次都再重新读图片的gt,直接加载这个文件就可以了,可以提升速度。    """    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') #给.pkl文件起个名字。参数self.cache_path和self.name继承自类imdb,类imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py里面被定义    if os.path.exists(cache_file): # 如果这个.pkl文件存在(说明之前执行过本函数,生成了这个pkl文件)        with open(cache_file, 'rb') as fid: #打开            roidb = cPickle.load(fid) #将里面的数据加载进来        print '{} gt roidb loaded from {}'.format(self.name, cache_file)        return roidb #返回    # 如果这个.pkl文件不存在,说明是第一次执行本函数。    gt_roidb = [self._load_pascal_annotation(index)                 for index in self.image_index] #那么首先要做的就是获取图片的gt,函数_load_pascal_annotation的作用就是获取图片gt。标号(4)    with open(cache_file, 'wb') as fid:        cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) #将图片的gt保存在.pkl文件里面    print 'wrote gt roidb to {}'.format(cache_file)    return gt_roidb

参数self.cache_path和self.name继承自类imdb,类imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py里面被定义。类imdb中定义函数self.cache_path的地方在imdb.py中的69-74行:

@propertydef cache_path(self):    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))  # 该路径是Faster-RCNN_TF/data/cache    if not os.path.exists(cache_path):        os.makedirs(cache_path)    return cache_path

类imdb中定义函数self.name的地方在imdb.py中的21-36行:

def __init__(self, name): #是类imdb的初始化函数,在pascal_voc.py的第26行被用到    # name是形参,传进来的参数是'voc_2007_train' or ‘voc_2007_test’ or 'voc_2007_val' or 'voc_2007_trainval'    self._name = name # 'voc_2007_train' or ‘voc_2007_test’ or 'voc_2007_val' or 'voc_2007_trainval'    self._num_classes = 0    self._classes = []    self._image_index = []    self._obj_proposer = 'selective_search'    self._roidb = None    print self.default_roidb    self._roidb_handler = self.default_roidb  # self._roidb_handler在Faster-RCNN_TF/lib/datasets/icdar_2015.py中,又被重新定义了    # Use this dict for storing dataset specific config options    self.config = {}@propertydef name(self): #类imdb中定义函数self.name的地方    return self._name #返回的是本文件imdb.py中的self._name,往上面看

注意:如果你再次训练的时候修改了train数据库,增加或者删除了一些数据,再想重新训练的时候,一定要先删除这个.pkl文件!!!!!!因为如果不删除的话,就会自动加载旧的pkl文件,而不会生成新的pkl文件。一定别忘了!

标号(4)def _load_pascal_annotation(self, index):这个函数是读取图片gt的具体实现

def _load_pascal_annotation(self, index):   """   :param index: 一张图片的名字(没有后缀.jpg)   Load image and bounding boxes info from XML file in the PASCAL VOC   format.从XML文件中获取图片信息和gt。   这个XML文件存储的是PASCAL VOC图片的信息和gt的信息,我们在下载VOC数据集的时候,XML文件是一块下载下来的。在文件夹Annotation里面。   """   filename = os.path.join(self._data_path, 'Annotations', index + '.xml') #这个filename就是一个XML文件的路径,其中index是一张图片的名字(没有后缀)。例如VOCdevkit/VOC2007/Annotations/000005.xml   tree = ET.parse(filename)   objs = tree.findall('object')   if not self.config['use_diff']:       # Exclude the samples labeled as difficult       non_diff_objs = [           obj for obj in objs if int(obj.find('difficult').text) == 0]       # if len(non_diff_objs) != len(objs):       #     print 'Removed {} difficult objects'.format(       #         len(objs) - len(non_diff_objs))       objs = non_diff_objs   num_objs = len(objs)  # 输进来的图片上的物体object的个数   boxes = np.zeros((num_objs, 4), dtype=np.uint16)   gt_classes = np.zeros((num_objs), dtype=np.int32)   overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)   # "Seg" area for pascal is just the box area   seg_areas = np.zeros((num_objs), dtype=np.float32)   # Load object bounding boxes into a data frame.   for ix, obj in enumerate(objs): # 对于该图片上每一个object       bbox = obj.find('bndbox') # pascal_voc的XML文件中给出的图片gt的形式是左上角和右下角的坐标       # Make pixel indexes 0-based       x1 = float(bbox.find('xmin').text) - 1        y1 = float(bbox.find('ymin').text) - 1       x2 = float(bbox.find('xmax').text) - 1       y2 = float(bbox.find('ymax').text) - 1 #为什么要减去1?是因为VOC的数据,坐标-1,默认坐标从0开始(这个还有待商榷,先忽略)       cls = self._class_to_ind[obj.find('name').text.lower().strip()]#找到该object的类别       boxes[ix, :] = [x1, y1, x2, y2]       gt_classes[ix] = cls       overlaps[ix, cls] = 1.0       seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) # seg_areas[ix]是该object gt的面积   overlaps = scipy.sparse.csr_matrix(overlaps)   return {'boxes' : boxes,           'gt_classes': gt_classes,           'gt_overlaps' : overlaps,           'flipped' : False,           'seg_areas' : seg_areas}

分析到现在,pascal_voc.py中还剩下一些函数,这些函数并没有在pascal_voc.py里面用到,但是在别的地方用到了,下面也分析一下:

def image_path_at(self, i)

'''根据第i个图像样本返回其对应的path,其调用了image_path_from_index(self, index)作为其具体实现;'''def image_path_at(self, i):    """    Return the absolute path to image i in the image sequence.    """    return self.image_path_from_index(self._image_index[i])

def image_path_from_index(self, index)

def image_path_from_index(self, index):    """    :param index: 是一张图片的名字,假如说有一张图片叫lsq.jpg,这个值就是lsq,没有后缀名    Construct an image path from the image's "index" identifier.    返回图片所在的路径    """    image_path = os.path.join(self._data_path, 'JPEGImages',                              index + self._image_ext) #这个就是图片本身所在的路径。其中index是一张图片的名字(没有后缀),_image_ext是图片后缀名.jpg。例如VOCdevkit/VOC2007/JPEGImages/000005.jpg    assert os.path.exists(image_path), \            'Path does not exist: {}'.format(image_path) # 如果该路径不存在,退出    return image_path

def selective_search_roidb(self)

def selective_search_roidb(self):    """    Return the database of selective search regions of interest.    Ground-truth ROIs are also included.    没有RPN的fast-rcnn提取候选框的方式。返回的是提取出来的ROI以及图片的gt。    这个函数在Faster-RCNN里面用不到,在fast-rcnn里面才会用到    This function loads/saves from/to a cache file to speed up future calls.    """    cache_file = os.path.join(self.cache_path,                              self.name + '_selective_search_roidb.pkl')    if os.path.exists(cache_file):        with open(cache_file, 'rb') as fid:            roidb = cPickle.load(fid)        print '{} ss roidb loaded from {}'.format(self.name, cache_file)        return roidb    if int(self._year) == 2007 or self._image_set != 'test':        gt_roidb = self.gt_roidb()        ss_roidb = self._load_selective_search_roidb(gt_roidb)        roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)    else:        roidb = self._load_selective_search_roidb(None)    with open(cache_file, 'wb') as fid:        cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)    print 'wrote ss roidb to {}'.format(cache_file)    return roidb

def _load_selective_search_roidb(self, gt_roidb)

def _load_selective_search_roidb(self, gt_roidb):    '''    加载预选框的文件    这个函数在Faster-RCNN里面用不到,在fast-rcnn里面才会用到。这个我还没有研究    '''     filename = os.path.abspath(os.path.join(cfg.DATA_DIR,                                             'selective_search_data',                                             self.name + '.mat'))     assert os.path.exists(filename), \            'Selective search data not found at: {}'.format(filename)     raw_data = sio.loadmat(filename)['boxes'].ravel()     box_list = []     for i in xrange(raw_data.shape[0]):         boxes = raw_data[i][:, (1, 0, 3, 2)] - 1         keep = ds_utils.unique_boxes(boxes)         boxes = boxes[keep, :]         keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])         boxes = boxes[keep, :]         box_list.append(boxes)     return self.create_roidb_from_box_list(box_list, gt_roidb)

3、编写自己的数据读写接口

我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据上面对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,只要设置好自己数据的路径就可以了。

详情见下篇。

原创粉丝点击