【深度学习:目标检测】 py-faster-rcnn标注FDDB人脸便于其在FDDB上进行测试

来源:互联网 发布:红外透视镜软件下载 编辑:程序博客网 时间:2024/05/22 22:39

转载:http://blog.csdn.net/xzzppp/article/details/52071460

本程序是在py-faster-rcnn/tools/demo.py的基础上进行修改的

程序功能:用训练好的caffemodel,对FDDB人脸进行标注,便于其在FDDB上进行测试

[python] view plain copy
  1. <span style="font-size:24px;">#!/usr/bin/env python   
  2.   
  3. # --------------------------------------------------------  
  4. # Faster R-CNN  
  5. # Copyright (c) 2015 Microsoft  
  6. # Licensed under The MIT License [see LICENSE for details]  
  7. # Written by Ross Girshick  
  8. # --------------------------------------------------------  
  9.   
  10. """ 
  11. Demo script showing detections in sample images. 
  12.  
  13. See README.md for installation instructions before running. 
  14. """  
  15.   
  16. import _init_paths  
  17. from fast_rcnn.config import cfg  
  18. from fast_rcnn.test import im_detect  
  19. from fast_rcnn.nms_wrapper import nms  
  20. from utils.timer import Timer  
  21. import matplotlib.pyplot as plt  
  22. import numpy as np  
  23. import scipy.io as sio  
  24. import caffe, os, sys, cv2  
  25. import argparse  
  26.   
  27. #CLASSES = ('__background__',  #背景 + 类  
  28. #           'aeroplane', 'bicycle', 'bird', 'boat',  
  29. #           'bottle', 'bus', 'car', 'cat', 'chair',  
  30. #           'cow', 'diningtable', 'dog', 'horse',  
  31. #           'motorbike', 'person', 'pottedplant',  
  32. #           'sheep', 'sofa', 'train', 'tvmonitor')  
  33.   
  34. CLASSES = ('__background__','face'#只有一类:face  
  35.   
  36. NETS = {'vgg16': ('VGG16',  
  37.                   'VGG16_faster_rcnn_final.caffemodel'),  
  38.         'myvgg': ('VGG_CNN_M_1024',  
  39.                   'VGG_CNN_M_1024_faster_rcnn_final.caffemodel'),  
  40.         'zf': ('ZF',  
  41.                   'ZF_faster_rcnn_final.caffemodel'),  
  42.         'myzf': ('ZF',  
  43.                   'zf_rpn_stage1_iter_80000.caffemodel'),  
  44. }  
  45.   
  46.   
  47. def vis_detections(im, class_name, dets, thresh=0.5):  
  48.     """Draw detected bounding boxes."""  
  49.     inds = np.where(dets[:, -1] >= thresh)[0]  
  50.     if len(inds) == 0:  
  51.         return  
  52.   
  53.     write_file.write(str(len(inds)) + '\n'#add by zhipeng  
  54.     im = im[:, :, (210)]  
  55.     #fig, ax = plt.subplots(figsize=(12, 12))  
  56.     #ax.imshow(im, aspect='equal')  
  57.     for i in inds:  
  58.         bbox = dets[i, :4]  
  59.         score = dets[i, -1]  
  60.   
  61.         ##########   add by zhipeng for write rectange to txt   ########  
  62.         write_file.write( "{} {} {} {} {}\n".format(str(bbox[0]), str(bbox[1]),  
  63.                                                         str(bbox[2] - bbox[0]),  
  64.                                                         str(bbox[3] - bbox[1]),  
  65.                                                         str(score)))  
  66.         #print "zhipeng, bbox:", bbox, "score:",score  
  67.         ##########   add by zhipeng for write rectange to txt   ########  
  68.   
  69.         '''''ax.add_patch( 
  70.             plt.Rectangle((bbox[0], bbox[1]), 
  71.                           bbox[2] - bbox[0], 
  72.                           bbox[3] - bbox[1], fill=False, 
  73.                           edgecolor='red', linewidth=3.5) 
  74.             ) 
  75.         ax.text(bbox[0], bbox[1] - 2, 
  76.                 '{:s} {:.3f}'.format(class_name, score), 
  77.                 bbox=dict(facecolor='blue', alpha=0.5), 
  78.                 fontsize=14, color='white') 
  79.  
  80.     ax.set_title(('{} detections with ' 
  81.                   'p({} | box) >= {:.1f}').format(class_name, class_name, 
  82.                                                   thresh), 
  83.                   fontsize=14) 
  84.     plt.axis('off') 
  85.     plt.tight_layout() 
  86.     plt.draw()'''  
  87.   
  88. def demo(net, image_name):  
  89.     """Detect object classes in an image using pre-computed object proposals."""  
  90.   
  91.     # Load the demo image  
  92.     #im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)  
  93.     im = cv2.imread(image_name)  
  94.   
  95.     # Detect all object classes and regress object bounds  
  96.     timer = Timer()  
  97.     timer.tic()  
  98.     scores, boxes = im_detect(net, im)  
  99.     timer.toc()  
  100.     print ('Detection took {:.3f}s for '  
  101.            '{:d} object proposals').format(timer.total_time, boxes.shape[0])  
  102.   
  103.     # Visualize detections for each class  
  104.     CONF_THRESH = 0.8  
  105.     NMS_THRESH = 0.3  
  106.     for cls_ind, cls in enumerate(CLASSES[1:]):  
  107.         cls_ind += 1 # because we skipped background  
  108.         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]  
  109.         cls_scores = scores[:, cls_ind]  
  110.         dets = np.hstack((cls_boxes,  
  111.                           cls_scores[:, np.newaxis])).astype(np.float32)  
  112.         keep = nms(dets, NMS_THRESH)  
  113.         dets = dets[keep, :]  
  114.         vis_detections(im, cls, dets, thresh=CONF_THRESH)  
  115.   
  116. def parse_args():  
  117.     """Parse input arguments."""  
  118.     parser = argparse.ArgumentParser(description='Faster R-CNN demo')  
  119.     parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',  
  120.                         default=0, type=int)  
  121.     parser.add_argument('--cpu', dest='cpu_mode',  
  122.                         help='Use CPU mode (overrides --gpu)',  
  123.                         action='store_true')  
  124.     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',  
  125.                         choices=NETS.keys(), default='vgg16')  
  126.   
  127.     args = parser.parse_args()  
  128.   
  129.     return args  
  130.   
  131. if __name__ == '__main__':  
  132.     cfg.TEST.HAS_RPN = True  # Use RPN for proposals  
  133.   
  134.     args = parse_args()  
  135.   
  136.     prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],  
  137.                             'faster_rcnn_alt_opt''faster_rcnn_test.pt')  
  138.     caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',  
  139.                               NETS[args.demo_net][1])  
  140.   
  141.     if not os.path.isfile(caffemodel):  
  142.         raise IOError(('{:s} not found.\nDid you run ./data/script/'  
  143.                        'fetch_faster_rcnn_models.sh?').format(caffemodel))  
  144.   
  145.     if args.cpu_mode:  
  146.         caffe.set_mode_cpu()  
  147.     else:  
  148.         caffe.set_mode_gpu()  
  149.         caffe.set_device(args.gpu_id)  
  150.         cfg.GPU_ID = args.gpu_id  
  151.     net = caffe.Net(prototxt, caffemodel, caffe.TEST)  
  152.   
  153.     print '\n\nLoaded network {:s}'.format(caffemodel)  
  154.   
  155.     # Warmup on a dummy image  
  156.     im = 128 * np.ones((3005003), dtype=np.uint8)  
  157.     for i in xrange(2):  
  158.         _, _= im_detect(net, im)  
  159.   
  160.     '''''im_names = ['000456.jpg', '000542.jpg', '001150.jpg', 
  161.                 '001763.jpg', '004545.jpg']'''  
  162.   
  163.     ##########   add by zhipeng for write rectange to txt   ########  
  164.     #write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/tools/detections/out.txt'  
  165.     #write_file = open(write_file_name, "w")  
  166.     ##########   add by zhipeng for write rectange to txt   ########  
  167.   
  168.     for current_file in range(1,11):      #orginal range(1, 11)  
  169.   
  170.         print 'Processing file ' + str(current_file) + ' ...'  
  171.   
  172.         read_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/tools/FDDB-fold/FDDB-fold-' + str(current_file).zfill(2) + '.txt'  
  173.         write_file_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/tools/detections/fold-' + str(current_file).zfill(2) + '-out.txt'  
  174.         write_file = open(write_file_name, "w")  
  175.   
  176.         with open(read_file_name, "r") as ins:  
  177.             array = []  
  178.             for line in ins:  
  179.                 array.append(line)      # list of strings  
  180.   
  181.         number_of_images = len(array)  
  182.   
  183.         for current_image in range(number_of_images):  
  184.             if current_image % 10 == 0:  
  185.                 print 'Processing image : ' + str(current_image)  
  186.             # load image and convert to gray  
  187.             read_img_name = '/home/xiao/code/py-faster-rcnn-master/py-faster-rcnn/data/FDDB/originalPics/' + array[current_image].rstrip() + '.jpg'  
  188.             write_file.write(array[current_image]) #add by zhipeng  
  189.             demo(net, read_img_name)  
  190.   
  191.         write_file.close()  
  192.   
  193.     '''''for im_name in im_names: 
  194.         print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' 
  195.         print 'Demo for data/demo/{}'.format(im_name) 
  196.         write_file.write(im_name + '\n') #add by zhipeng 
  197.         demo(net, im_name)'''  
  198.   
  199.     #write_file.close()  # add by zhipeng,close file  
  200.     plt.show()  
  201. </span>  

0 0