Faster-Rcnn训练自己的数据集

来源:互联网 发布:中韩贸易数据 编辑:程序博客网 时间:2024/04/30 04:32

1、首先,先从github下载faster-rcnn的安装包
https://github.com/rbgirshick/py-faster-rcnn
根据里面的Readme配置好caffe环境。试运行demo是否成功。如果,成功则进行下面的步骤,否 则,说明caffe配置失败。这个不属于这篇文章讨论范畴,可以自行查看文献,正确安装。


2、处理自己的数据集
为了不麻烦,出现各种路径bug。所以,建议直接在原有的 ../py-faster-rcnn/data/VOCdevkit2007/目录下添加修改自己的数据集。这里提醒一下,如果,按照 官方ln -s $VOCdevkit VOCdevkit2007命令,操作没有成功,则建议直接将VOCdevkit改名字为VOCdevkit2007。
这里只介绍检测,不介绍分割。所以,需要更改的有../py-faster-rcnn/data/VOCdevkit2007/VOC2007/文件下的三个文件内容。
- Annotations 目标的boudingbox的xml文件
- ImageSets 需要更改Layout 和Main两个目录,内容里面是相同的包含四个.txt文件;train.txt trainval.txt val.txt test.txt从文件名可以看出是干嘛的。
- JPEGImages 训练和测试使用的图片目录
下面是我使用的根据.txt文件生成.xml文件的matlab程序;如果,你自己已经有.xml文件了,则跳过此步骤。
我的.txt内容格式如下,因为我只有检测行人一类。所以,我就不用设置类别了。

002566.jpg 237 170 251 199文件名       左上角,右下角
%%%该代码可以做voc2007数据集中的xml文件,%txt文件每行格式为:000001.jpg dog 48 240 195 371%即每行由图片名、目标类型、包围框坐标组成,空格隔开%如果一张图片有多个目标,则格式如下:(比如两个目标)% 000001.jpg dog 48 240 195 371% 000001.jpg person 8 12 352 498% 000002.jpg train 139 200 207 301% 000003.jpg sofa 123 155 215 195% 000003.jpg chair 239 156 307 205%包围框坐标为左上角和右下角%%clc;clear;%注意修改下面四个变量%imgpath='/home/ubuntu/data/VOCdevkit/shujv/JPEGImages/';%图像存放文件夹imgpath='/home/zsj/data/py-faster-rcnn/data/VOCdevkit/VOC2007/JPEGImages/';txtpath='person.txt';%txtpath='/home/ubuntu/data/VOCdevkit/shujv/mytest.txt';%txt文件xmlpath_new='./';% mkdir(xmlpath_new)% xmlpath_new='xml/';%修改后的xml保存文件夹foldername='VOC2012';%xml的folder字段名fidin=fopen(txtpath,'r');lastname='begin';while ~feof(fidin)     tline=fgetl(fidin);     str = regexp(tline, ' ','split');     filepath=[imgpath,str{1}];     img=imread(filepath);     [h,w,d]=size(img);      imshow(img);      rectangle('Position',[str2double(str{2}),str2double(str{3}),str2double(str{4})-str2double(str{2}),str2double(str{5})-str2double(str{3})],'LineWidth',4,'EdgeColor','r');%       pause(0.1);        if strcmp(str{1},lastname)%如果文件名相等,只需增加object           object_node=Createnode.createElement('object');           Root.appendChild(object_node);           node=Createnode.createElement('name');           node.appendChild(Createnode.createTextNode(sprintf('%s','person')));           object_node.appendChild(node);           node=Createnode.createElement('pose');           node.appendChild(Createnode.createTextNode(sprintf('%s','Unspecified')));           object_node.appendChild(node);           node=Createnode.createElement('truncated');           node.appendChild(Createnode.createTextNode(sprintf('%s','0')));           object_node.appendChild(node);           node=Createnode.createElement('difficult');           node.appendChild(Createnode.createTextNode(sprintf('%s','0')));           object_node.appendChild(node);           bndbox_node=Createnode.createElement('bndbox');           object_node.appendChild(bndbox_node);           node=Createnode.createElement('xmin');           node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{2}))));           bndbox_node.appendChild(node);           node=Createnode.createElement('ymin');           node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{3}))));           bndbox_node.appendChild(node);           node=Createnode.createElement('xmax');           node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{4}))));           bndbox_node.appendChild(node);           node=Createnode.createElement('ymax');           node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{5}))));           bndbox_node.appendChild(node);        else %如果文件名不等,则需要新建xml           copyfile(filepath, 'JPEGImages');            %先保存上一次的xml           if exist('Createnode','var')              tempname=lastname;              tempname=strrep(tempname,'.jpg','.xml');              xmlwrite(tempname,Createnode);              end            Createnode=com.mathworks.xml.XMLUtils.createDocument('annotation');            Root=Createnode.getDocumentElement;%根节点            node=Createnode.createElement('folder');            node.appendChild(Createnode.createTextNode(sprintf('%s',foldername)));            Root.appendChild(node);            node=Createnode.createElement('filename');            node.appendChild(Createnode.createTextNode(sprintf('%s',str{1})));            Root.appendChild(node);            source_node=Createnode.createElement('source');            Root.appendChild(source_node);            node=Createnode.createElement('database');            node.appendChild(Createnode.createTextNode(sprintf('The VOC2007 Database')));            source_node.appendChild(node);            node=Createnode.createElement('annotation');            node.appendChild(Createnode.createTextNode(sprintf('PASCAL VOC2007')));            source_node.appendChild(node);           node=Createnode.createElement('image');           node.appendChild(Createnode.createTextNode(sprintf('flickr')));           source_node.appendChild(node);           node=Createnode.createElement('flickrid');           node.appendChild(Createnode.createTextNode(sprintf('NULL')));           source_node.appendChild(node);           owner_node=Createnode.createElement('owner');           Root.appendChild(owner_node);           node=Createnode.createElement('flickrid');           node.appendChild(Createnode.createTextNode(sprintf('NULL')));           owner_node.appendChild(node);           node=Createnode.createElement('name');           node.appendChild(Createnode.createTextNode(sprintf('watersink')));           owner_node.appendChild(node);           size_node=Createnode.createElement('size');           Root.appendChild(size_node);          node=Createnode.createElement('width');          node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(w))));          size_node.appendChild(node);          node=Createnode.createElement('height');          node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(h))));          size_node.appendChild(node);         node=Createnode.createElement('depth');         node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(d))));         size_node.appendChild(node);          node=Createnode.createElement('segmented');          node.appendChild(Createnode.createTextNode(sprintf('%s','0')));          Root.appendChild(node);          object_node=Createnode.createElement('object');          Root.appendChild(object_node);          node=Createnode.createElement('name');          node.appendChild(Createnode.createTextNode(sprintf('%s','person')));          object_node.appendChild(node);          node=Createnode.createElement('pose');          node.appendChild(Createnode.createTextNode(sprintf('%s','Unspecified')));          object_node.appendChild(node);          node=Createnode.createElement('truncated');          node.appendChild(Createnode.createTextNode(sprintf('%s','0')));          object_node.appendChild(node);          node=Createnode.createElement('difficult');          node.appendChild(Createnode.createTextNode(sprintf('%s','0')));          object_node.appendChild(node);          bndbox_node=Createnode.createElement('bndbox');          object_node.appendChild(bndbox_node);         node=Createnode.createElement('xmin');         node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{2}))));         bndbox_node.appendChild(node);         node=Createnode.createElement('ymin');         node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{3}))));         bndbox_node.appendChild(node);        node=Createnode.createElement('xmax');        node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{4}))));        bndbox_node.appendChild(node);        node=Createnode.createElement('ymax');        node.appendChild(Createnode.createTextNode(sprintf('%s',num2str(str{5}))));        bndbox_node.appendChild(node);       lastname=str{1};        end        %处理最后一行        if feof(fidin)            tempname=lastname;            tempname=strrep(tempname,'.jpg','.xml');            xmlwrite(tempname,Createnode);        endendfclose(fidin);% file=dir(pwd);% for i=1:length(file)%    if length(file(i).name)>=4 && strcmp(file(i).name(end-3:end),'.xml')%     fold=fopen([file(i).name ],'r');%     fnew=fopen([xmlpath_new file(i).name],'w');%     line=1;%     while ~feof(fold)%         tline=fgetl(fold);%         if line==1%            line=2;%            continue;%         end%         expression = '   ';%         replace=char(9);%         newStr=regexprep(tline,expression,replace);%         fprintf(fnew,'%s\n',newStr);%     end%     fprintf('已处理%s\n',file(i).name);%     fclose(fold);%     fclose(fnew);%   delete(file(i).name);%    end% end

3、训练自己的模型
根据自己的训练数据集的特点,有几个类别,自行更改.prototxt
这里给出我使用的网络更改地方。要根据不同网络更改不同的地方。但是,本质都是更改输入输出的大小,适应更改类别种数后的数据。
我所使用的网络:py-faster-rcnn/models/pascal_voc/VGG_CNN_M_1024/faster_rcnn_end2end/
train.prototxt更改内容为以下四个地方:

        layer {        name: 'data'        type: 'Python'        top: 'data'        top: 'rois'        top: 'labels'        top: 'bbox_targets'        top: 'bbox_inside_weights'        top: 'bbox_outside_weights'        python_param {          module: 'roi_data_layer.layer'          layer: 'RoIDataLayer'          param_str: "'num_classes': 2" #按训练集类别改,该值为类别数+1        }      }  
    layer {        name: "cls_score"        type: "InnerProduct"        bottom: "fc7"        top: "cls_score"        param { lr_mult: 1.0 }        param { lr_mult: 2.0 }        inner_product_param {          num_output: 2 #按训练集类别改,该值为类别数+1          weight_filler {            type: "gaussian"            std: 0.01          }          bias_filler {            type: "constant"            value: 0          }        }      }  
    layer {        name: "bbox_pred"        type: "InnerProduct"        bottom: "fc7"        top: "bbox_pred"        param { lr_mult: 1.0 }        param { lr_mult: 2.0 }        inner_product_param {          num_output: 8 #按训练集类别改,该值为(类别数+1)*4          weight_filler {            type: "gaussian"            std: 0.001          }          bias_filler {            type: "constant"            value: 0          }        }      }  
layer {  name: 'roi-data'  type: 'Python'  bottom: 'rpn_rois'  bottom: 'gt_boxes'  top: 'rois'  top: 'labels'  top: 'bbox_targets'  top: 'bbox_inside_weights'  top: 'bbox_outside_weights'  python_param {    module: 'rpn.proposal_target_layer'    layer: 'ProposalTargetLayer'    param_str: "'num_classes': 2"  #类别数+1  }

至此,网络更改完毕。
输入训练指令:

./experiments/scripts/faster_rcnn_end2end.sh 0 VGG_CNN_M_1024 pascal_voc

开始正式训练模型,但是由于python版本兼容性问题,会出现跟种问题。这里只给出部分最常见的错误修改方法。主要参考http://blog.csdn.net/mydear_11000/article/details/70241139

Problem 1

AttributeError: 'module' object has no attributetext_format'

解决方法:在/home/xxx/py-faster-rcnn/lib/fast_rcnn/train.py的头文件导入部分加上 :import google.protobuf.text_format

Problem 2

TypeError: 'numpy.float64' object cannot be interpreted as an index 

这里是因为numpy版本不兼容导致的问题,最好的解决办法是卸载你的numpy,安装numpy1.11.0。如果你和笔者一样不是服务器的网管,没有权限的话,就只能自己想办法解决了。
修改如下几个地方的code:

1) /home/xxx/py-faster-rcnn/lib/roi_data_layer/minibatch.py

将第26行:fg_rois_per_image = np.round(cfg.TRAIN.FG_FRACTION * rois_per_image)改为:fg_rois_per_image = np.round(cfg.TRAIN.FG_FRACTION * rois_per_image).astype(np.int)

2) /home/xxx/py-faster-rcnn/lib/datasets/ds_utils.py

将第12行:hashes = np.round(boxes * scale).dot(v)改为:hashes = np.round(boxes * scale).dot(v).astype(np.int)

3) /home/xxx/py-faster-rcnn/lib/fast_rcnn/test.py

将第129行: hashes = np.round(blobs['rois'] * cfg.DEDUP_BOXES).dot(v)改为: hashes = np.round(blobs['rois'] * cfg.DEDUP_BOXES).dot(v).astype(np.int)

4) /home/xxx/py-faster-rcnn/lib/rpn/proposal_target_layer.py

将第60行:fg_rois_per_image = np.round(cfg.TRAIN.FG_FRACTION * rois_per_image)改为:fg_rois_per_image = np.round(cfg.TRAIN.FG_FRACTION * rois_per_image).astype(np.int)

Problem3

TypeError: slice indices must be integers or None or have an __index__ method

这里还是因为numpy版本的原因,最好的解决办法还是换numpy版本(见problem2),但同样也有其他的解决办法。
修改 /home/lzx/py-faster-rcnn/lib/rpn/proposal_target_layer.py,转到123行:

for ind in inds:        cls = clss[ind]        start = 4 * cls        end = start + 4        bbox_targets[ind, start:end] = bbox_target_data[ind, 1:]        bbox_inside_weights[ind, start:end] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS    return bbox_targets, bbox_inside_weights

这里的ind,start,end都是 numpy.int 类型,这种类型的数据不能作为索引,所以必须对其进行强制类型转换,转化结果如下:

for ind in inds:        ind = int(ind)        cls = clss[ind]        start = int(4 * cls)        end = int(start + 4)        bbox_targets[ind, start:end] = bbox_target_data[ind, 1:]        bbox_inside_weights[ind, start:end] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS    return bbox_targets, bbox_inside_weights