faster rcnn demo.py:在一个窗口显示所有类别标注

来源:互联网 发布:it helpdesk 工作内容 编辑:程序博客网 时间:2024/06/05 02:30

转载地址:http://blog.csdn.net/10km/article/details/68926498

方便自己随时看。



faster rcnn 的demo.py运行时,对于同一个图像,每个类别显示一个窗口,看起来不太方便,顺便小改一下,让一幅图像中检测到的所有类别物体都在一个窗口下标注,就方便多了。

代码改动也不复杂:把 vis_detections 函数中 for 循环之前的三行代码(创建 figure 并显示图像)移到 demo 函数的 for 循环之前,把 for 循环之后的三行代码移到 demo 函数的 for 循环之后,再把 ax 作为参数传给 vis_detections 即可。
完整代码如下(顺便把标注框的线宽从 3.5 改成了 1,原来的线太粗,不好看):
py-faster-rcnn/tools/demo.py (注意代码中本人添加的注释)
#!/usr/bin/env python#coding=utf8# 因为代码中我加了中文注释,所以 上面这行用于指定编码 ,否则python代码执行会报错 # --------------------------------------------------------# Faster R-CNN# Copyright (c) 2015 Microsoft# Licensed under The MIT License [see LICENSE for details]# Written by Ross Girshick# --------------------------------------------------------"""Demo script showing detections in sample images.See README.md for installation instructions before running."""import _init_pathsfrom fast_rcnn.config import cfgfrom fast_rcnn.test import im_detectfrom fast_rcnn.nms_wrapper import nmsfrom utils.timer import Timerimport matplotlib.pyplot as pltimport numpy as npimport scipy.io as sioimport caffe, os, sys, cv2import argparseCLASSES = ('__background__',           'aeroplane', 'bicycle', 'bird', 'boat',           'bottle', 'bus', 'car', 'cat', 'chair',           'cow', 'diningtable', 'dog', 'horse',           'motorbike', 'person', 'pottedplant',           'sheep', 'sofa', 'train', 'tvmonitor')NETS = {'vgg16': ('VGG16',                  'VGG16_faster_rcnn_final.caffemodel'),        'zf': ('ZF',                  'ZF_faster_rcnn_final.caffemodel')}#增加ax参数def vis_detections(im, class_name, dets, ax, thresh=0.5):    """Draw detected bounding boxes."""    inds = np.where(dets[:, -1] >= thresh)[0]    if len(inds) == 0:        return# 删除这三行#     im = im[:, :, (2, 1, 0)]#     fig, ax = plt.subplots(figsize=(12, 12))#     ax.imshow(im, aspect='equal')    for i in inds:        bbox = dets[i, :4]        score = dets[i, -1]        ax.add_patch(            plt.Rectangle((bbox[0], bbox[1]),                          bbox[2] - bbox[0],                          bbox[3] - bbox[1], fill=False,                          edgecolor='red', linewidth=1) # 矩形线宽从3.5改为1            )        ax.text(bbox[0], bbox[1] - 2,                '{:s} {:.3f}'.format(class_name, score),                bbox=dict(facecolor='blue', alpha=0.5),                fontsize=14, color='white')    ax.set_title(('{} detections with '                  
'p({} | box) >= {:.1f}').format(class_name, class_name,                                                  thresh),                  fontsize=14)# 删除这三行#     plt.axis('off')#     plt.tight_layout()#     plt.draw()def demo(net, image_name):    """Detect object classes in an image using pre-computed object proposals."""    # Load the demo image    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)    im = cv2.imread(im_file)    # Detect all object classes and regress object bounds    timer = Timer()    timer.tic()    scores, boxes = im_detect(net, im)    timer.toc()    print ('Detection took {:.3f}s for '           '{:d} object proposals').format(timer.total_time, boxes.shape[0])    # Visualize detections for each class    CONF_THRESH = 0.8    NMS_THRESH = 0.3    # 将vis_detections 函数中for 循环之前的3行代码移动到这里    im = im[:, :, (2, 1, 0)]    fig,ax = plt.subplots(figsize=(12, 12))    ax.imshow(im, aspect='equal')     for cls_ind, cls in enumerate(CLASSES[1:]):        cls_ind += 1 # because we skipped background        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]        cls_scores = scores[:, cls_ind]        dets = np.hstack((cls_boxes,                          cls_scores[:, np.newaxis])).astype(np.float32)        keep = nms(dets, NMS_THRESH)        dets = dets[keep, :]        #将ax做为参数传入vis_detections        vis_detections(im, cls, dets, ax,thresh=CONF_THRESH)    # 将vis_detections 函数中for 循环之后的3行代码移动到这里    plt.axis('off')    plt.tight_layout()    plt.draw()def parse_args():    """Parse input arguments."""    parser = argparse.ArgumentParser(description='Faster R-CNN demo')    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',                        default=0, type=int)    parser.add_argument('--cpu', dest='cpu_mode',                        help='Use CPU mode (overrides --gpu)',                        action='store_true')    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',                        choices=NETS.keys(), 
default='vgg16')    args = parser.parse_args()    return argsif __name__ == '__main__':    cfg.TEST.HAS_RPN = True  # Use RPN for proposals    args = parse_args()    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],                            'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',                              NETS[args.demo_net][1])    if not os.path.isfile(caffemodel):        raise IOError(('{:s} not found.\nDid you run ./data/script/'                       'fetch_faster_rcnn_models.sh?').format(caffemodel))    if args.cpu_mode:        caffe.set_mode_cpu()    else:        caffe.set_mode_gpu()        caffe.set_device(args.gpu_id)        cfg.GPU_ID = args.gpu_id    net = caffe.Net(prototxt, caffemodel, caffe.TEST)    print '\n\nLoaded network {:s}'.format(caffemodel)    # Warmup on a dummy image    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)    for i in xrange(2):        _, _= im_detect(net, im)    im_names = ['000456.jpg', '000542.jpg', '001150.jpg',                '001763.jpg', '004545.jpg']    for im_name in im_names:        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'        print 'Demo for data/demo/{}'.format(im_name)        demo(net, im_name)    plt.show()

原创粉丝点击