Faster-RCNN_TF代码解读16:roi_data_layer/roidb.py

来源:互联网 发布:sap采购订单数据库 编辑:程序博客网 时间:2024/06/05 19:43
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Transform a roidb into a trainable roidb by adding a bunch of metadata."""

import numpy as np
from fast_rcnn.config import cfg
from fast_rcnn.bbox_transform import bbox_transform
from utils.cython_bbox import bbox_overlaps
# Import the submodule explicitly: a bare `import PIL` does not guarantee that
# the `PIL.Image` attribute is available. This still binds the name `PIL`.
import PIL.Image


def prepare_roidb(imdb):
    """Enrich the imdb's roidb in place with derived, training-time fields.

    For every entry of ``imdb.roidb`` this stores:

    * ``'image'``        -- the value of ``imdb.image_path_at(i)``
    * ``'width'``/``'height'`` -- the image size as reported by PIL
    * ``'max_overlaps'`` -- row-wise maximum of the dense ``gt_overlaps``
    * ``'max_classes'``  -- column index at which that maximum occurs

    Args:
        imdb: dataset object exposing ``roidb``, ``image_index``,
            ``num_images`` and ``image_path_at(i)``.

    Raises:
        AssertionError: if an ROI with zero max overlap is not assigned
            class 0, or an ROI with positive overlap is assigned class 0.
    """
    # Image sizes are read up front for all `imdb.num_images` images.
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
             for i in range(imdb.num_images)]
    roidb = imdb.roidb
    # Walk over every image entry (presumably including augmented, e.g.
    # flipped, entries -- confirm against the imdb implementation).
    for i in range(len(imdb.image_index)):
        # Record the image path and its width / height.
        roidb[i]['image'] = imdb.image_path_at(i)
        roidb[i]['width'] = sizes[i][0]
        roidb[i]['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        # (`gt_overlaps` is presumably stored as a sparse matrix; `toarray()`
        # expands it back to a dense 2-D array.)
        gt_overlaps = roidb[i]['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        # Attach both derived arrays to this image's dict.
        roidb[i]['max_classes'] = max_classes
        roidb[i]['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        # NOTE(review): the original annotator could not find where the
        # background class is defined in the XML annotations -- open question.
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)


def add_bbox_regression_targets(roidb):
    """Add information needed to train bounding-box regressors.

    Stores ``roidb[i]['bbox_targets']`` for every image (see
    ``_compute_targets``), derives per-class means and standard deviations of
    the four regression values, prints them, and optionally normalizes the
    stored targets in place.

    Args:
        roidb: non-empty list of per-image dicts that has already been
            processed by ``prepare_roidb``.

    Returns:
        Tuple ``(means.ravel(), stds.ravel())``; both source arrays have
        shape ``(num_classes, 4)`` before flattening.

    Raises:
        AssertionError: if ``roidb`` is empty or lacks ``'max_classes'``.
    """
    assert len(roidb) > 0
    assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'

    # Number of image entries (presumably including horizontally flipped
    # copies -- confirm against the caller).
    num_images = len(roidb)
    # Infer number of classes from the number of columns in gt_overlaps
    # (annotator's note: 21 including background for their dataset --
    # not established by this file).
    num_classes = roidb[0]['gt_overlaps'].shape[1]
    for im_i in range(num_images):
        rois = roidb[im_i]['boxes']
        max_overlaps = roidb[im_i]['max_overlaps']
        max_classes = roidb[im_i]['max_classes']
        # One row per ROI: column 0 is the label, columns 1..4 the targets.
        roidb[im_i]['bbox_targets'] = \
                _compute_targets(rois, max_overlaps, max_classes)

    # Annotator's note: this flag is False in their configuration -- verify
    # against cfg.
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
        # Use fixed / precomputed "means" and "stds" instead of empirical values
        means = np.tile(
                np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (num_classes, 1))
        stds = np.tile(
                np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (num_classes, 1))
    else:
        # Compute values needed for means and stds
        # var(x) = E(x^2) - E(x)^2
        # Starting the counts at cfg.EPS avoids dividing by zero for classes
        # with no samples (annotator's note: cfg.EPS is 1e-14 -- verify).
        class_counts = np.zeros((num_classes, 1)) + cfg.EPS
        sums = np.zeros((num_classes, 4))
        squared_sums = np.zeros((num_classes, 4))
        for im_i in range(num_images):
            targets = roidb[im_i]['bbox_targets']
            # Class 0 is skipped here (treated as background above).
            for cls in range(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                if cls_inds.size > 0:
                    # Number of ROIs of this class in this image.
                    class_counts[cls] += cls_inds.size
                    # Column-wise sum of the four target values.
                    sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
                    # Column-wise sum of the squared target values.
                    squared_sums[cls, :] += \
                            (targets[cls_inds, 1:] ** 2).sum(axis=0)

        # Per-class mean.
        means = sums / class_counts
        # Per-class standard deviation (square root of the variance).
        stds = np.sqrt(squared_sums / class_counts - means ** 2)

    # Single-argument, parenthesized print produces the same output on
    # Python 2 and also runs on Python 3.
    print('bbox target means:')
    print(means)
    print(means[1:, :].mean(axis=0))  # ignore bg class
    print('bbox target stdevs:')
    print(stds)
    print(stds[1:, :].mean(axis=0))  # ignore bg class

    # Normalize targets
    # Each class is standardized separately over its four target columns.
    # Annotator's note: this flag is True in their configuration -- verify
    # against cfg.
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS:
        print("Normalizing targets")
        for im_i in range(num_images):
            targets = roidb[im_i]['bbox_targets']
            for cls in range(1, num_classes):
                cls_inds = np.where(targets[:, 0] == cls)[0]
                roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]
                roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]
    else:
        print("NOT normalizing targets")

    # These values will be needed for making predictions
    # (the predicts will need to be unnormalized and uncentered)
    return means.ravel(), stds.ravel()


# Annotator's note: when called with (boxes, max_overlaps, max_classes) of a
# roidb that holds only ground-truth boxes, every ROI here is a GT box.
def _compute_targets(rois, overlaps, labels):
    """Compute bounding-box regression targets for an image.

    Args:
        rois: 2-D array of boxes, one row per ROI.
        overlaps: per-ROI maximum overlap; a value of exactly 1 marks a
            ground-truth ROI.
        labels: per-ROI class label.

    Returns:
        ``float32`` array of shape ``(rois.shape[0], 5)``. Column 0 holds the
        label and columns 1..4 the output of ``bbox_transform`` (presumably
        ``(dx, dy, dw, dh)``) for ROIs whose overlap is at least
        ``cfg.TRAIN.BBOX_THRESH``; all other rows stay zero.
    """
    # Indices of ground-truth ROIs
    gt_inds = np.where(overlaps == 1)[0]
    if len(gt_inds) == 0:
        # Bail if the image has no ground-truth ROIs
        # (annotator's note: not expected in practice because images without
        # any usable ROI are filtered out earlier -- not shown in this file).
        return np.zeros((rois.shape[0], 5), dtype=np.float32)
    # Indices of examples for which we try to make predictions
    # (annotator's note: cfg.TRAIN.BBOX_THRESH is 0.5 -- verify against cfg).
    ex_inds = np.where(overlaps >= cfg.TRAIN.BBOX_THRESH)[0]

    # Get IoU overlap between each ex ROI and gt ROI; the result has one row
    # per ex ROI and one column per gt ROI.
    # np.float64 replaces the removed `np.float` alias (same dtype).
    ex_gt_overlaps = bbox_overlaps(
        np.ascontiguousarray(rois[ex_inds, :], dtype=np.float64),
        np.ascontiguousarray(rois[gt_inds, :], dtype=np.float64))

    # Find which gt ROI each ex ROI has max overlap with:
    # this will be the ex ROI's gt target
    gt_assignment = ex_gt_overlaps.argmax(axis=1)
    gt_rois = rois[gt_inds[gt_assignment], :]
    ex_rois = rois[ex_inds, :]

    # Rows not selected by ex_inds keep label 0 and zero targets.
    targets = np.zeros((rois.shape[0], 5), dtype=np.float32)
    targets[ex_inds, 0] = labels[ex_inds]
    targets[ex_inds, 1:] = bbox_transform(ex_rois, gt_rois)
    return targets
原创粉丝点击