This demo shows the method proposed in Zhou, Bolei, et al., "Learning Deep Features for Discriminative Localization," arXiv preprint arXiv:1512.04150 (2015).

The proposed method can automatically localize the discriminative regions in an image using global average pooling (GAP) in CNNs.

You can download the pretrained Inception-V3 network from here. Other networks with a similar structure (global average pooling applied right after the last conv feature map) should also work.

In [1]:
# -*- coding: UTF-8 -*-
# Notebook setup: plotting backend, project root, module search path, imports.
import matplotlib.pyplot as plt

# Equivalent of the `%matplotlib inline` notebook magic. `get_ipython` only
# exists inside an IPython/Jupyter session, so the call is skipped elsewhere.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass

from IPython import display
import os

# Directory that contains the `lib`, `models` and `sample_pics` folders.
ROOT_DIR = '.'

import sys

# Must run before the imports below: puts ROOT_DIR/lib first on the module
# search path (presumably so a locally built package there wins - TODO confirm).
sys.path.insert(0, os.path.join(ROOT_DIR, 'lib'))

import cv2
import numpy as np
import mxnet as mx

Set the image you want to test and the classification network you want to use. Note that `conv_layer` must be the last conv layer before the average pooling layer.

In [2]:
# --- Input image and model files -------------------------------------------
im_file = os.path.join(ROOT_DIR, 'sample_pics/barbell.jpg')
synset_file = os.path.join(ROOT_DIR, 'models/inception-v3/synset.txt')
net_json = os.path.join(ROOT_DIR, 'models/inception-v3/Inception-7-symbol.json')
params = os.path.join(ROOT_DIR, 'models/inception-v3/Inception-7-0001.params')

# --- Network layer names ---------------------------------------------------
# Output of the last conv layer (the one feeding the average pooling layer).
conv_layer = 'ch_concat_mixed_10_chconcat_output'
# Output holding the class probabilities.
prob_layer = 'softmax_output'
# Name prefix of the final fully connected layer; its weights are looked up
# as arg_fc + '_weight'.
arg_fc = 'fc1'

# --- Preprocessing ---------------------------------------------------------
mean = (128, 128, 128)      # per-channel value subtracted from the input
raw_scale = 1.0             # multiplier applied before mean subtraction
input_scale = 1.0/128       # multiplier applied after mean subtraction
width = 299                 # network input width in pixels
height = 299                # network input height in pixels
resize_size = 340           # not referenced by the code visible in this file

# --- Visualisation / execution ---------------------------------------------
top_n = 5                   # number of highest-scoring classes to visualise
ctx = mx.cpu(1)             # MXNet CPU context with device id 1

Load the label name of each class.

In [3]:
# One class label per line; the list index is the class index.
# The `with` block closes the file handle as soon as the labels are read.
with open(synset_file) as synset_fh:
    synset = [line.strip() for line in synset_fh]

Build network symbol and load network parameters.

In [4]:
# Load the full network definition, then regroup it so that a forward pass
# returns exactly two outputs: [class probabilities, last conv feature map].
symbol = mx.sym.load(net_json)
internals = symbol.get_internals()
symbol = mx.sym.Group([internals[prob_layer], internals[conv_layer]])

# Keys in the parameter file are of the form '<type>:<name>'. Entries typed
# 'arg' and 'aux' are split into two dictionaries; any other prefix is ignored.
save_dict = mx.nd.load(params)
arg_params = {}
aux_params = {}
for k, v in save_dict.items():
    l2_tp, name = k.split(':', 1)
    if l2_tp == 'arg':
        arg_params[name] = v
    elif l2_tp == 'aux':
        aux_params[name] = v

# Model wrapper used for prediction; one image is fed per batch.
mod = mx.model.FeedForward(symbol,
                           arg_params=arg_params,
                           aux_params=aux_params,
                           ctx=ctx,
                           allow_extra_params=False,
                           numpy_batch_size=1)
  1. Read the weights of the fc layer in the softmax classification layer. The bias can be neglected since it does not really affect the result.
  2. Load the image you want to test and convert it from BGR to RGB (OpenCV uses BGR by default).
In [5]:
# Weights of the final fully connected layer as a NumPy array. The matching
# bias (arg_fc + '_bias') is deliberately not loaded: it is neglected when
# computing the activation maps.
weight_fc = arg_params[arg_fc + '_weight'].asnumpy()

# cv2.imread signals an unreadable or missing file by returning None rather
# than raising, so fail here with a clear message.
im = cv2.imread(im_file)
if im is None:
    raise IOError('Failed to read image: %s' % im_file)

# RGB copy at network input size, used only for display (OpenCV loads BGR).
rgb = cv2.cvtColor(cv2.resize(im, (width, height)), cv2.COLOR_BGR2RGB)

Feed the image data to our network and get the outputs.

We select the top 5 classes for visualization by default.

In [6]:
def im2blob(im, width, height, mean=None, input_scale=1.0, raw_scale=1.0, swap_channel=True):
    """Convert an image into a single-item NCHW float32 input blob.

    Processing order: resize -> HWC to CHW -> optional swap of channels 0 and
    2 -> multiply by ``raw_scale`` -> subtract ``mean`` -> multiply by
    ``input_scale``.

    Args:
        im: Image array accepted by ``cv2.resize``; must have 3 channels.
        width: Target width in pixels.
        height: Target height in pixels.
        mean: ``None`` (no subtraction), a tuple/list of three per-channel
            values, or an ``np.ndarray`` that broadcasts against the
            ``(1, 3, height, width)`` blob.
        input_scale: Multiplier applied after mean subtraction.
        raw_scale: Multiplier applied before mean subtraction.
        swap_channel: If true, exchange channel 0 and channel 2
            (e.g. BGR <-> RGB).

    Returns:
        ``np.float32`` array of shape ``(1, 3, height, width)``.

    Raises:
        TypeError: If ``mean`` is none of the supported types.
    """
    # cv2.resize expects dsize as (width, height) and returns an array shaped
    # (height, width, channels); passing them the other way round only works
    # for square targets.
    blob = cv2.resize(im, (width, height)).astype(np.float32)
    blob = blob.reshape((1, height, width, 3))
    # NHWC -> NCHW
    blob = blob.transpose((0, 3, 1, 2))

    if swap_channel:
        # Fancy indexing on the right-hand side copies, so the swap is safe.
        blob[:, [0, 2], :, :] = blob[:, [2, 0], :, :]
    if raw_scale != 1.0:
        blob *= raw_scale

    if isinstance(mean, np.ndarray):
        blob -= mean
    elif isinstance(mean, (tuple, list)):
        blob[:, 0, :, :] -= mean[0]
        blob[:, 1, :, :] -= mean[1]
        blob[:, 2, :, :] -= mean[2]
    elif mean is not None:
        raise TypeError('mean should be either a tuple or a np.ndarray')

    if input_scale != 1.0:
        blob *= input_scale
    return blob
In [7]:
# Preprocess the BGR image; swap_channel=True reorders it to RGB for the net.
blob = im2blob(im, width, height, mean=mean, swap_channel=True,
               raw_scale=raw_scale, input_scale=input_scale)

# The grouped symbol yields [probabilities, conv feature map]; index 0 picks
# the single image in the batch.
outputs = mod.predict(blob)
score = outputs[0][0]
conv_fm = outputs[1][0]

# Indices of the top_n classes by descending score, and their scores.
inds_sort = np.argsort(-score)[:top_n]
score_sort = score[inds_sort]

Localize the discriminative regions by analysing the class's response in the network's last conv feature map.

In [8]:
def get_cam(conv_feat_map, weight_fc):
    """Compute class activation maps from conv features and fc weights.

    Each output map is the weighted sum of the feature-map channels, using one
    row of ``weight_fc`` as the channel weights.

    Args:
        conv_feat_map: Array of shape ``(C, H, W)`` for a single image or
            ``(N, C, H, W)`` for a batch.
        weight_fc: 2-D array of shape ``(M, C)``, one row per class.

    Returns:
        Array of shape ``(M, H, W)`` for a 3-D input, or ``(N, M, H, W)`` for
        a 4-D input (the batched result is allocated with ``np.zeros`` and is
        therefore float64).

    Raises:
        AssertionError: If ``weight_fc`` is not 2-D or its second dimension
            does not equal ``C``.
        ValueError: If ``conv_feat_map`` is neither 3-D nor 4-D.
    """
    assert len(weight_fc.shape) == 2
    rank = len(conv_feat_map.shape)

    if rank == 3:
        C, H, W = conv_feat_map.shape
        assert weight_fc.shape[1] == C
        # (M, C) x (C, H*W) -> (M, H*W): one flattened map per class.
        detection_map = weight_fc.dot(conv_feat_map.reshape(C, H*W))
        detection_map = detection_map.reshape(-1, H, W)
    elif rank == 4:
        N, C, H, W = conv_feat_map.shape
        assert weight_fc.shape[1] == C
        M = weight_fc.shape[0]
        detection_map = np.zeros((N, M, H, W))
        for i in range(N):
            tmp_detection_map = weight_fc.dot(conv_feat_map[i].reshape(C, H*W))
            detection_map[i, :, :, :] = tmp_detection_map.reshape(-1, H, W)
    else:
        raise ValueError(
            'conv_feat_map must be 3-D (C, H, W) or 4-D (N, C, H, W), got %d-D' % rank)
    return detection_map
In [9]:
# One row of 1 + top_n panels: the input image, then one heat-map overlay per
# top-scoring class.
plt.figure(figsize=(18, 6))
plt.subplot(1, 1+top_n, 1)
plt.imshow(rgb)

# Activation maps for the top_n classes only, in descending score order.
cam = get_cam(conv_fm, weight_fc[inds_sort, :])

# Loop-invariant pieces of the overlay, computed once.
cam_f32 = cam.astype(np.float32)
dimmed_rgb = rgb.astype(np.float32)/255*0.3

for k in range(top_n):
    detection_map = np.squeeze(cam_f32[k, :, :])
    # Upsample the coarse map to the displayed image size.
    heat_map = cv2.resize(detection_map, (width, height))
    # NOTE(review): the value reported as "max_response" is the mean of the
    # un-resized map, not its maximum - confirm which one was intended.
    max_response = detection_map.mean()

    # Scale so the strongest response is 1. A non-positive peak is left
    # unscaled to avoid dividing by zero or flipping the sign of the map.
    peak = heat_map.max()
    if peak > 0:
        heat_map /= peak

    # The colormap returns RGBA; drop alpha and blend 30% image / 70% heat map.
    im_show = dimmed_rgb + plt.cm.jet(heat_map)[:, :, :3]*0.7
    plt.subplot(1, 1+top_n, k+2)
    plt.imshow(im_show)
    print('Top %d: %s(%.6f), max_response=%.4f' % (k+1, synset[inds_sort[k]], score_sort[k], max_response))
Top 1: n02790996 barbell(0.977569), max_response=11.4226
Top 2: n03255030 dumbbell(0.011445), max_response=7.0496
Top 3: n04487394 trombone(0.000077), max_response=1.8076
Top 4: n03535780 horizontal bar, high bar(0.000060), max_response=1.7168
Top 5: n03400231 frying pan, frypan, skillet(0.000046), max_response=1.0619