py-faster-rcnn训练自己的数据

来源:互联网 发布:mac开机黑屏有进度条 编辑:程序博客网 时间:2024/04/30 03:42

本人作为初入深度学习的小白,写这篇博客纯属为了记录自己的成长过程,把自己踏过的坑和大家分享一下,也请各位大牛不吝指正。我自己做实验时参考了samylee的文章,博主非常热心,有问题也可以咨询他。

1.实验目的

本人刚刚接触深度学习的时候正值CCF大数据比赛,赛题是检测图片中的交通标志(驭势科技)傻不拉几的用selective search选取proposals,resize224*224丢到alexnet里结果一大把莫名奇妙的框,后来才知道自己多么的年轻。比赛数据一直保存至今,再加上自己对detection有了初步的了解,硬件问题也得到了解决,变重新开始做实验。

2.前提准备

2.1最重要的是py-faster-rcnn的配置与安装,由于本人实验做的不多,以下所有的内容只对有GPU的朋友适用。

2.2下载VOC2007的数据,这里提供一个网盘地址,下载后解压到py-faster-rcnn\data下

2.3下载pre-train的model,这里提供一个网盘地址,下载后放到py-faster-rcnn\data下,完成后的效果如下


3.数据准备

3.1这个地方就是我卡住的地方,大部分blog就是直接说用你的数据替换掉原始的voc数据就好。。。(小白黑人问号脸。。。)在路径/home/bitss/py-faster-rcnn/data/VOCdevkit2007/VOC2007中共有三个有用的文件夹,分别是Annotations,ImageSets,JPEGImages。JPEGImages存储所有的图片,这个文件夹最好搞定。

3.2接着是Annotations,这里面是用xml标注的文件,真的是搞死我了,我从没用过xml,这东西怎么搞,而且如果一个图片里有多个类别或者一个类别存在好几个怎么办?(我说的情况是数据已经有相关的标定,但是不是xml形式的。如果你压根没标定,github上搜索LabelImg,可以手工标定自动生成xml文件)小白这里用python像写文本文件那样写了一个脚本,代码如下(相关路径自己改,代码很蠢,但是能用啊):

#encoding:utf-8import ospath_open='/home/bitss/py-faster-rcnn/data/VOCdevkit2007/VOC2007/txt'path_write='/home/bitss/py-faster-rcnn/data/VOCdevkit2007/VOC2007/Annotations'for files in os.walk(path_open):    for file in files[2]:        f1 = open(path_open+'/'+file)        f2 = open(path_write+'/'+file.replace('txt','xml'),'w')        f2.write('<annotation>\n')        f2.write('        <folder>VOC2007</folder>\n')        f2.write('        <filename>'+file.replace('txt','jpg')+'</filename>\n')        f2.write('        <size>\n')        f2.write('                <width>1280</width>\n')        f2.write('                <height>720</height>\n')        f2.write('                <depth>3</depth>\n')        f2.write('        </size>\n')        f2.write('        <segmented>0</segmented>\n')        for line in f1:            f2.write('        <object>\n')            f2.write('                <name>'+line.split(' ')[1]+'</name>\n')            f2.write('                <pose>Unspecified</pose>\n')            f2.write('                <truncated>0</truncated>\n')            f2.write('                <difficult>0</difficult>\n')            f2.write('                <bndbox>\n')            f2.write('                        <xmin>'+line.split(' ')[2]+'</xmin>\n')            f2.write('                        <ymin>' + line.split(' ')[3] + '</ymin>\n')            f2.write('                        <xmax>' + line.split(' ')[4] + '</xmax>\n')            f2.write('                        <ymax>' + line.split(' ')[5].replace('/r/n','') + '</ymax>\n')            f2.write('                </bndbox>\n')            f2.write('        </object>\n')        f2.write('</annotation>\n')
其中的txt文件夹里的txt文件的形式如下(图片名,类别名,x_min,y_min,x_max,y_max)就比赛数据而言,给了一个总的csv文件标注每张图片中的交通标志文件位置,这些txt是我已经整理好的,这里不再贴出代码,毕竟应该不会用到:

3.3最后是ImageSets,该路径下分别有main、layout和segmentation,里边存储train.txt,trainval.txt,val.txt,test.txt。如何制作这些txt,这里借liumaolincycle的代码(必须先替换掉Annotations中的内容,不要打乱顺序):

%writetxt.mfile = dir('Annotations');len = length(file)-2;num_trainval=sort(randperm(len, floor(9*len/10)));%trainval集占所有数据的9/10,可以根据需要设置num_train=sort(num_trainval(randperm(length(num_trainval), floor(5*length(num_trainval)/6))));%train集占trainval集的5/6,可以根据需要设置num_val=setdiff(num_trainval,num_train);%trainval集剩下的作为val集num_test=setdiff(1:len,num_trainval);%所有数据中剩下的作为test集path = 'ImageSets\Main\';fid=fopen(strcat(path, 'trainval.txt'),'a+');for i=1:length(num_trainval)    s = sprintf('%s',file(num_trainval(i)+2).name);    fprintf(fid,[s(1:length(s)-4) '\n']);endfclose(fid);fid=fopen(strcat(path, 'train.txt'),'a+');for i=1:length(num_train)    s = sprintf('%s',file(num_train(i)+2).name);    fprintf(fid,[s(1:length(s)-4) '\n']);endfclose(fid);fid=fopen(strcat(path, 'val.txt'),'a+');for i=1:length(num_val)    s = sprintf('%s',file(num_val(i)+2).name);    fprintf(fid,[s(1:length(s)-4) '\n']);endfclose(fid);fid=fopen(strcat(path, 'test.txt'),'a+');for i=1:length(num_test)    s = sprintf('%s',file(num_test(i)+2).name);    fprintf(fid,[s(1:length(s)-4) '\n']);endfclose(fid);
4.修改文件以及训练

这里请参考之前samylee的文章,修改几处参数并训练

5.中途可能出现的问题

5.1这里assert(boxes[:, 2] >= boxes[:, 0]).all()可能出现AssertionError,具体解决办法参考:

    http://blog.csdn.net/xzzppp/article/details/52036794

5.2TypeError numpy.float64 object cannot be interpreted as an index

sudo pip uninstall numpy(我是因为下载tensorflow的时候下了一个numpy1.12,与py-fatser-rcnn不兼容,报错本身因为多个numpy库,系统不知道用哪个)

6.运行demo

还是一样,参考samylee的文章

最后大家注意自己用的模型,用ZF就要去ZF里的文件夹修改相关内容,VGG同样如此。

最后祝愿大家学习、工作顺利,也请各位大神不吝指教。

0 0
原创粉丝点击