使用faster rcnn训练自己的模型

来源:互联网 发布:极品五笔2013优化版 编辑:程序博客网 时间:2024/05/01 17:27

    • 安装caffe
    • 数据预处理
    • 对于训练代码的修改
    • 遇到问题
    • 参考性很强

转载请注明:http://blog.csdn.net/c602273091/article/details/53945485

安装caffe

可以看我之前的博客:
安装caffe
安装faster rcnn:
faster rcnn

数据预处理

进行数据标注:
https://github.com/saicoco/object_labelImg

我这里主要是使用python对xml进行处理。

生成 xml 标注文件的代码:

# -*- coding:utf-8 -*-
"""Generate PASCAL VOC2007-style XML annotations from a ``label.idl`` file.

Every non-blank line of the label file is parsed as a JSON object mapping an
image file name to a list of boxes ``[xmin, ymin, xmax, ymax, class]``.  One
XML file per image is written into the annotation folder.  Images whose box
list is empty are reported ("Empty List") and get no XML file, because images
without objects must not be recorded.
"""
import codecs
import collections
import json
import os
import xml.dom
import xml.dom.minidom

__author__ = "Yu Chen"

# Output folder for the generated XML files.
_ANNOTATION_SAVE_FOLDER_ = 'Annotations'

# XML layout and fixed field values.
_INDENT = '\t'  # alternatively ' ' * 4
_NEW_LINE = '\n'
_FOLDER_NODE = 'VOC2007'
_ROOT_NODE = 'annotation'
_DATABASE_NAME = 'INRIA'
_ANNOTATION = 'PASCAL VOC2007'
_AUTHOR = 'Yu Chen'
_SEGMENTED = '0'
_DIFFICULT = '0'
_TRUNCATED = '0'
_POSE = 'Unspecified'
# NOTE: every image is assumed to be 640x360 with 3 channels; the real image
# size is never read from disk.
_IMAGE_HEIGHT = 360
_IMAGE_WIDTH = 640
_IMAGE_CHANNEL = 3


def createElementNode(doc, tag, attr):
    """Return a new ``<tag>attr</tag>`` element created by *doc*."""
    element_node = doc.createElement(tag)
    # The text becomes the element's only child.
    element_node.appendChild(doc.createTextNode(attr))
    return element_node


def createChildNode(doc, tag, attr, parent_node):
    """Append a new ``<tag>attr</tag>`` element to *parent_node*."""
    parent_node.appendChild(createElementNode(doc, tag, attr))


def createObjectNode(doc, attrs):
    """Return an ``<object>`` element describing one bounding box.

    :param doc: the document that owns the new nodes.
    :param attrs: dict with string values for ``classification``, ``xmin``,
        ``ymin``, ``xmax`` and ``ymax``.
    """
    object_node = doc.createElement('object')
    createChildNode(doc, 'name', attrs['classification'], object_node)
    createChildNode(doc, 'pose', _POSE, object_node)
    createChildNode(doc, 'truncated', _TRUNCATED, object_node)
    createChildNode(doc, 'difficult', _DIFFICULT, object_node)
    bndbox_node = doc.createElement('bndbox')
    for key in ('xmin', 'ymin', 'xmax', 'ymax'):
        createChildNode(doc, key, attrs[key], bndbox_node)
    object_node.appendChild(bndbox_node)
    return object_node


def writeXMLFile(doc, filename):
    """Serialize *doc* to *filename* without the XML declaration or blank lines.

    The document is rendered in memory.  The previous version went through a
    ``tmp.xml`` scratch file in the working directory and opened the output
    file twice, leaking one handle.  Whitespace-only lines are dropped because
    a document re-read with ``minidom.parse`` keeps its old indentation as
    text nodes, which would otherwise be emitted as extra blank lines.
    """
    pretty = doc.toprettyxml(indent=_INDENT, newl=_NEW_LINE)
    # The first line is the ``<?xml ...?>`` declaration, which is omitted.
    lines = pretty.split(_NEW_LINE)[1:]
    # Without a declaration, XML readers assume UTF-8, so write UTF-8.
    with codecs.open(filename, 'w', encoding='utf-8') as fout:
        for line in lines:
            if line.strip():
                fout.write(line + _NEW_LINE)


def _createAnnotationDocument(image_name, width, height):
    """Return a new VOC annotation document with all header nodes and no objects."""
    my_dom = xml.dom.getDOMImplementation()
    doc = my_dom.createDocument(None, _ROOT_NODE, None)
    root_node = doc.documentElement

    createChildNode(doc, 'folder', _FOLDER_NODE, root_node)
    createChildNode(doc, 'filename', image_name, root_node)

    source_node = doc.createElement('source')
    createChildNode(doc, 'database', _DATABASE_NAME, source_node)
    createChildNode(doc, 'annotation', _ANNOTATION, source_node)
    createChildNode(doc, 'image', 'flickr', source_node)
    createChildNode(doc, 'flickrid', 'NULL', source_node)
    root_node.appendChild(source_node)

    owner_node = doc.createElement('owner')
    createChildNode(doc, 'flickrid', 'NULL', owner_node)
    createChildNode(doc, 'name', _AUTHOR, owner_node)
    root_node.appendChild(owner_node)

    size_node = doc.createElement('size')
    createChildNode(doc, 'width', str(width), size_node)
    createChildNode(doc, 'height', str(height), size_node)
    createChildNode(doc, 'depth', str(_IMAGE_CHANNEL), size_node)
    root_node.appendChild(size_node)

    createChildNode(doc, 'segmented', _SEGMENTED, root_node)
    return doc


def createXMLFile(attrs, width, height, filename):
    """Write a new annotation file for image ``attrs['name']`` with one object.

    :param attrs: dict with ``name`` plus the keys read by createObjectNode.
    :param width: image width written to ``<size>``.
    :param height: image height written to ``<size>``.
    :param filename: output XML path (overwritten if it exists).
    """
    doc = _createAnnotationDocument(attrs['name'], width, height)
    doc.documentElement.appendChild(createObjectNode(doc, attrs))
    writeXMLFile(doc, filename)


def _readLabels(label_path):
    """Return an OrderedDict: image name -> list of createObjectNode attr dicts.

    An image that appears on several lines accumulates the boxes of all of
    them.  Empty box lists are reported and skipped.
    """
    boxes_by_image = collections.OrderedDict()
    with open(label_path, 'r') as fid:
        for line in fid:
            if not line.strip():
                continue  # json.loads would reject a blank line
            data = json.loads(line)
            for image_key, boxes in data.items():
                if not boxes:
                    print("Empty List")
                    continue
                objects = boxes_by_image.setdefault(str(image_key), [])
                for bbx in boxes:
                    objects.append({
                        'xmin': str(bbx[0]),
                        'ymin': str(bbx[1]),
                        'xmax': str(bbx[2]),
                        'ymax': str(bbx[3]),
                        'classification': str(bbx[4]),
                    })
    return boxes_by_image


def main(label_path='training/label.idl', save_folder=_ANNOTATION_SAVE_FOLDER_,
         width=_IMAGE_WIDTH, height=_IMAGE_HEIGHT):
    """Convert *label_path* into one XML annotation per image in *save_folder*.

    Every XML file is rebuilt from scratch and written once.  Re-running the
    script therefore no longer appends duplicate ``<object>`` nodes to files
    left over from an earlier run.
    """
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    docs = collections.OrderedDict()
    for image_name, objects in _readLabels(label_path).items():
        # Same naming rule as before: keep the part before the first dot.
        xml_file_name = os.path.join(save_folder,
                                     image_name.split('.')[0] + '.xml')
        doc = docs.get(xml_file_name)
        if doc is None:
            print(xml_file_name)
            doc = _createAnnotationDocument(image_name, width, height)
            docs[xml_file_name] = doc
        for attrs in objects:
            doc.documentElement.appendChild(createObjectNode(doc, attrs))

    for xml_file_name, doc in docs.items():
        writeXMLFile(doc, xml_file_name)


if __name__ == "__main__":
    main()

生成 ImageSets/Main 下 txt 划分文件的代码:

# -*- coding:utf-8 -*-
"""Write the VOC2007 ImageSets/Main split files (trainval/train/val/test).

trainval.txt lists the stem of every ``.xml`` file in the annotation folder.
The first ``_TRAIN_NUMBER`` of them (in sorted order) go to train.txt and the
rest go to val.txt.  test.txt lists a fixed range of numeric image ids.
"""
import os

__author__ = 'Yu Chen'

# ImageSets folders
_IMAGE_SETS_PATH = 'ImageSets'
_MAin_PATH = 'ImageSets/Main'
_XML_FILE_PATH = 'Annotations'
# Number of annotated images that go to train.txt; the rest go to val.txt.
_TRAIN_NUMBER = 6000
# test.txt holds the ids in the half-open range [_TEST_NUM, _TEST_END).
_TEST_NUM = 70091
_TEST_END = 72091


def _write_ids(path, ids):
    """Write each id in *ids* on its own line to *path*, overwriting it."""
    with open(path, 'w') as out:
        for image_id in ids:
            out.write(str(image_id) + '\n')


def main(xml_path=_XML_FILE_PATH, main_path=_MAin_PATH,
         train_number=_TRAIN_NUMBER, test_ids=None):
    """Write test/trainval/train/val.txt into *main_path*.

    :param xml_path: folder holding the ``<id>.xml`` annotation files.
    :param main_path: output folder; created together with its parents if missing.
    :param train_number: how many ids go to train.txt (exactly this many; the
        previous ``num > _TRAIN_NUMBER`` test put one extra id into train).
    :param test_ids: ids for test.txt; defaults to ``range(_TEST_NUM, _TEST_END)``.
    """
    if test_ids is None:
        test_ids = range(_TEST_NUM, _TEST_END)

    if os.path.isdir(main_path):
        print('Main dir is already in ImageSets')
    else:
        os.makedirs(main_path)

    # Only top-level *.xml files are annotations.  Sorting makes the
    # train/val split reproducible instead of depending on filesystem order.
    xml_files = sorted(f for f in os.listdir(xml_path) if f.endswith('.xml'))
    print(len(xml_files))
    ids = [os.path.splitext(f)[0] for f in xml_files]

    _write_ids(os.path.join(main_path, 'trainval.txt'), ids)
    _write_ids(os.path.join(main_path, 'train.txt'), ids[:train_number])
    _write_ids(os.path.join(main_path, 'val.txt'), ids[train_number:])
    # NOTE(review): test ids are not checked against trainval; make sure no
    # annotation for these ids is present in xml_path, or they will leak into
    # training.
    _write_ids(os.path.join(main_path, 'test.txt'), test_ids)


if __name__ == '__main__':
    main()

主要参考了:
1、http://blog.csdn.net/sinat_30071459/article/details/50723212
2、http://blog.csdn.net/gvfdbdf/article/details/52214008
3、https://github.com/Parlefan/create-voc2007-dataset/blob/master/create_ImageSets.py
4、https://github.com/Parlefan/create-voc2007-dataset/blob/master/create_JPEGImages.py
5、https://saicoco.github.io/object-detection-4/
6、http://www.cnblogs.com/louyihang-loves-baiyan/p/4885659.html
7、http://www.cnblogs.com/louyihang-loves-baiyan/p/4903231.html

对于训练代码的修改

主要是参考了:
http://blog.csdn.net/sinat_30071459/article/details/51332084

1、http://www.voidcn.com/blog/sinat_30071459/article/p-5957360.html
2、http://www.cnblogs.com/CarryPotMan/p/5390336.html

遇到问题

1、error 1:assert (boxes[:, 2] >= boxes[:, 0]).all()
将py-faster-rcnn/lib/datasets/imdb.py中的相应代码改成如下代码即可:

def append_flipped_images(self):
    """Mirror every image's ground-truth boxes horizontally.

    Excerpt of py-faster-rcnn's ``imdb.append_flipped_images``.  Image widths
    are read from the image files with PIL instead of being cached.  After
    mirroring, any box whose new x1 lies to the right of its new x2 gets its
    x1 forced to 0, so the sanity assert below no longer fires.
    """
    num_images = self.num_images
    # Width of every image, read from disk.
    widths = [PIL.Image.open(self.image_path_at(idx)).size[0]
              for idx in range(num_images)]
    for i, width in enumerate(widths):
        boxes = self.roidb[i]['boxes'].copy()
        left = boxes[:, 0].copy()
        right = boxes[:, 2].copy()
        # Mirror: the new x1 comes from the old x2 and vice versa.
        boxes[:, 0] = width - right - 1
        boxes[:, 2] = width - left - 1
        # Workaround for inverted boxes.  width - right - 1 is negative when an
        # annotated xmax is >= the image width.  NOTE(review): boxes is
        # presumably an unsigned-integer array, so that value wraps to a huge
        # x1 -- confirm the dtype; fixing the annotations is the real cure.
        inverted = boxes[:, 2] < boxes[:, 0]
        boxes[inverted, 0] = 0
        assert (boxes[:, 2] >= boxes[:, 0]).all()
        # NOTE(review): the snippet ends here; in imdb.py the original code
        # that follows the assert inside this loop is kept unchanged.

2、IndexError: list index out of range

删除fast-rcnn-master/data/cache/ 文件夹下的.pkl文件,或者改名备份,重新训练即可。

3、image_num 相关的 assert 失败 / divide by 0(除以 0)错误。
这是因为制作 xml 时,不能记录没有目标的图片:不含任何目标的图片不要为其生成 xml。

参考了:
1、https://github.com/rbgirshick/py-faster-rcnn/issues
2、https://github.com/rbgirshick/fast-rcnn/issues/
3、http://blog.csdn.net/marshwb/article/details/50451548
4、http://blog.csdn.net/sinat_30071459/article/details/51332084
5、http://blog.csdn.net/xzzppp/article/details/52036794

参考性很强

以下参考都基于自己的数据集,非常实用:
http://www.cnblogs.com/louyihang-loves-baiyan/p/4906690.html
http://blog.csdn.net/sinat_30071459/article/details/50723212
http://download.csdn.net/detail/sinat_30071459/9531172
http://download.csdn.net/detail/sinat_30071459/9532108
https://saicoco.github.io/object-detection-4/
http://blog.csdn.net/sinat_30071459/article/details/51332084

1 0