ChainerCV下用自己的数据集训练Faster RCNN

来源:互联网 发布:尚学堂 java视频 编辑:程序博客网 时间:2024/05/22 12:38

一、ChainerCV

        ChainerCV,一个用于计算机视觉的深度学习实用库。这个库旨在通过 Chainer 简易化计算机视觉的训练和深度学习模型应用的过程。它包含计算机视觉模型的高质量实现,以及开展计算机视觉研究的必备工具集。当前,ChainerCV 提供了目标检测和语义分割模型(Faster R-CNN、SSD 和 SegNet)的实现。

        https://github.com/chainer/chainercv

二、修改原始代码

        1,chainercv/datasets/voc/voc_utils.py中

(1)更改voc_bbox_label_names,改为自己数据库的类标;增加oc_semantic_segmentation_label_colors,和类标相对应

        2,chainercv/datasets/voc/voc_bbox_dataset.py中,

(1)第8行from chainercv.datasets.voc import voc_utils 改为 import voc_utils;

(2)在get_example中,将obj.find('name').text.lower().strip()改为obj.find('name').text.strip(),因为.lower()为把所有的namee变为小写,而我们的数据集的label name是区分大小写的

        3,examples/faster_rcnn/train.py中,

(1)文件开头加入 import sys,sys.path.append(r'/home/wang/Development/chainercv/chainercv/datasets/voc');并将

       from chainercv.datasets import voc_bbox_label_names

       from chainercv.datasets import VOCBboxDataset

改为from voc_utils import voc_bbox_label_names

       from voc_bbox_dataset import VOCBboxDataset

这样导入的label name为我们修改后的,而不是原始VOC的;

(2)并在main()中

train_data = VOCBboxDataset(split='trainval', year='2007')加入路径变为

train_data = VOCBboxDataset('/home/wang/Development/VOCdevkit2007/VOC2007',split='trainval', year='2007'),

同样的test_data = VOCBboxDataset(data_dir='/home/wang/Development/VOCdevkit2007/VOC2007',split='test', year='2007',use_difficult=True, return_difficult=True);

这里最重要,加了自己数据库的路径后,就直接导入自己的数据库,若没有路径,会在VOC链接下载VOC数据库,就会导致错误。

         4,examples/faster_rcnn/demo.py中,

(1) 文件开头加入 import sys,sys.path.append(r'/home/wang/Development/chainercv/chainercv/datasets/voc');

(2) 并将 from chainercv.datasets import voc_bbox_label_names 改为 from voc_utils import voc_bbox_label_names。和train的原因一样

三,训练 

python train.py --iteration 70000  --gpu 0

四,测试

python demo.py --gpu 1 --pretrained_model /home/wang/Development/chainercv/examples/faster_rcnn/result/snapshot_model.npz /home/wang/Development/VOCdevkit2007/VOC2007/JPEGImages/000102.jpg

注:在换了电脑后训练和测试出现问题

这是因为版本不兼容导致GPU出问题,用命令sudo pip install cupy==2.0.0将cupy降为2.0.0,问题解决。此时版本信息为:


阅读全文
0 0