class_active_maps
This demo shows the method proposed in Zhou, Bolei, et al., "Learning Deep Features for Discriminative Localization," arXiv preprint arXiv:1512.04150 (2015).
The proposed method can automatically localize the discriminative regions in an image using global average pooling (GAP) in CNNs.
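In the paper's notation, let $f_k(x, y)$ be the activation of channel $k$ of the last conv feature map at location $(x, y)$, and $w^c_k$ the fc weight connecting the GAP output of channel $k$ to class $c$. Because the class score is a weighted sum of spatially averaged feature maps, it decomposes over locations, and the class activation map is simply the weighted channel sum:

$$
M_c(x, y) = \sum_k w^c_k\, f_k(x, y), \qquad S_c = \frac{1}{Z} \sum_{x, y} M_c(x, y),
$$

where $Z$ is the number of spatial locations. The `get_cam` function below computes exactly this weighted sum.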
You can download the pretrained Inception-V3 network from here. Other networks with a similar structure (global average pooling right after the last conv feature map) should also work.
```python
# -*- coding: UTF-8 -*-
import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display
import os
ROOT_DIR = '.'
import sys
sys.path.insert(0, os.path.join(ROOT_DIR, 'lib'))
import cv2
import numpy as np
import mxnet as mx
```
Set the image you want to test and the classification network you want to use. Note that "conv_layer" should be the last conv layer before the global average pooling layer.
```python
im_file = os.path.join(ROOT_DIR, 'sample_pics/barbell.jpg')
synset_file = os.path.join(ROOT_DIR, 'models/inception-v3/synset.txt')
net_json = os.path.join(ROOT_DIR, 'models/inception-v3/Inception-7-symbol.json')
conv_layer = 'ch_concat_mixed_10_chconcat_output'
prob_layer = 'softmax_output'
arg_fc = 'fc1'
params = os.path.join(ROOT_DIR, 'models/inception-v3/Inception-7-0001.params')
mean = (128, 128, 128)
raw_scale = 1.0
input_scale = 1.0/128
width = 299
height = 299
resize_size = 340
top_n = 5
ctx = mx.cpu(1)
```
Load the label name of each class.
```python
synset = [l.strip() for l in open(synset_file).readlines()]
```
Build network symbol and load network parameters.
```python
symbol = mx.sym.load(net_json)
internals = symbol.get_internals()
# Group the softmax output and the last conv feature map into one symbol,
# so a single forward pass returns both.
symbol = mx.sym.Group([internals[prob_layer], internals[conv_layer]])
save_dict = mx.nd.load(params)
arg_params = {}
aux_params = {}
# Saved parameter names are prefixed with 'arg:' or 'aux:'.
for k, v in save_dict.items():
    tp, name = k.split(':', 1)
    if tp == 'arg':
        arg_params[name] = v
    if tp == 'aux':
        aux_params[name] = v
mod = mx.model.FeedForward(symbol, arg_params=arg_params, aux_params=aux_params,
                           ctx=ctx, allow_extra_params=False, numpy_batch_size=1)
```
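If you swap in a different network and are unsure what its last conv output is called, you can list the symbol's internal outputs. A minimal sketch using the `internals` object created above:

```python
# List the internal outputs of the loaded symbol; the last conv output
# (the one feeding global average pooling) is usually near the end.
all_outputs = internals.list_outputs()
print(all_outputs[-10:])
# For this Inception-V3 json, the entry we want is
# 'ch_concat_mixed_10_chconcat_output' (set as conv_layer above).
```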
- Read the weights of the fc layer that feeds the softmax classifier. The bias can be neglected, since a per-class constant offset does not affect where the map is most active.
- Load the image you want to test and convert it from BGR to RGB (OpenCV uses BGR by default).
```python
weight_fc = arg_params[arg_fc+'_weight'].asnumpy()
# bias_fc = arg_params[arg_fc+'_bias'].asnumpy()
im = cv2.imread(im_file)
rgb = cv2.cvtColor(cv2.resize(im, (width, height)), cv2.COLOR_BGR2RGB)
```
Feed the image data to our network and get the outputs.
We select the top 5 classes for visualization by default.
```python
def im2blob(im, width, height, mean=None, input_scale=1.0, raw_scale=1.0,
            swap_channel=True):
    # Resize to the network input size; cv2.resize expects (width, height).
    blob = cv2.resize(im, (width, height)).astype(np.float32)
    blob = blob.reshape((1, height, width, 3))
    # from nhwc to nchw
    blob = np.swapaxes(blob, 2, 3)
    blob = np.swapaxes(blob, 1, 2)
    if swap_channel:
        # BGR -> RGB
        blob[:, [0, 2], :, :] = blob[:, [2, 0], :, :]
    if raw_scale != 1.0:
        blob *= raw_scale
    if isinstance(mean, np.ndarray):
        blob -= mean
    elif isinstance(mean, (tuple, list)):
        blob[:, 0, :, :] -= mean[0]
        blob[:, 1, :, :] -= mean[1]
        blob[:, 2, :, :] -= mean[2]
    elif mean is None:
        pass
    else:
        raise TypeError('mean should be either a tuple or a np.ndarray')
    if input_scale != 1.0:
        blob *= input_scale
    return blob
```
```python
blob = im2blob(im, width, height, mean=mean, swap_channel=True,
               raw_scale=raw_scale, input_scale=input_scale)
outputs = mod.predict(blob)
score = outputs[0][0]    # softmax class probabilities
conv_fm = outputs[1][0]  # last conv feature map (C, H, W)
score_sort = -np.sort(-score)[:top_n]
inds_sort = np.argsort(-score)[:top_n]
```
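As a quick sanity check, the shapes should line up as follows. The spatial sizes in the comments assume the standard 299x299 Inception-V3; the class count is whatever `synset.txt` holds:

```python
# Verify the outputs before computing CAMs.
print(score.shape)      # (num_classes,): softmax probabilities
print(conv_fm.shape)    # e.g. (2048, 8, 8): last conv feature map (C, H, W)
print(weight_fc.shape)  # (num_classes, C): one weight row per class
assert weight_fc.shape[1] == conv_fm.shape[0]
assert len(score) == len(synset)
```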
Localize the discriminative regions by analysing each class's response in the network's last conv feature map.
```python
def get_cam(conv_feat_map, weight_fc):
    assert len(weight_fc.shape) == 2
    if len(conv_feat_map.shape) == 3:
        C, H, W = conv_feat_map.shape
        assert weight_fc.shape[1] == C
        # CAM = per-class weighted sum over the C feature-map channels.
        detection_map = weight_fc.dot(conv_feat_map.reshape(C, H*W))
        detection_map = detection_map.reshape(-1, H, W)
    elif len(conv_feat_map.shape) == 4:
        N, C, H, W = conv_feat_map.shape
        assert weight_fc.shape[1] == C
        M = weight_fc.shape[0]
        detection_map = np.zeros((N, M, H, W))
        for i in range(N):
            tmp_detection_map = weight_fc.dot(conv_feat_map[i].reshape(C, H*W))
            detection_map[i, :, :, :] = tmp_detection_map.reshape(-1, H, W)
    return detection_map
```
```python
plt.figure(figsize=(18, 6))
plt.subplot(1, 1+top_n, 1)
plt.imshow(rgb)
cam = get_cam(conv_fm, weight_fc[inds_sort, :])
for k in range(top_n):
    detection_map = np.squeeze(cam.astype(np.float32)[k, :, :])
    # Upsample the low-resolution CAM to the input image size.
    heat_map = cv2.resize(detection_map, (width, height))
    # Mean response of the raw CAM, reported alongside the class score.
    max_response = detection_map.mean()
    heat_map /= heat_map.max()
    # Blend the image (30%) with the jet-colored heatmap (70%).
    im_show = rgb.astype(np.float32)/255*0.3 + plt.cm.jet(heat_map)[:, :, :3]*0.7
    plt.subplot(1, 1+top_n, k+2)
    plt.imshow(im_show)
    print('Top %d: %s(%.6f), max_response=%.4f'
          % (k+1, synset[inds_sort[k]], score_sort[k], max_response))
plt.show()
```
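The paper also turns CAMs into bounding boxes by thresholding the heatmap at 20% of its max value and boxing the largest connected component of the surviving region. A simplified sketch of that idea for the top-1 class (it boxes all above-threshold pixels rather than just the largest connected component):

```python
# Rough localization from the top-1 CAM: keep locations above 20% of the
# max response (the threshold used in the paper), then box those pixels.
cam_top1 = cv2.resize(np.squeeze(cam[0]).astype(np.float32), (width, height))
mask = cam_top1 > 0.2 * cam_top1.max()
ys, xs = np.nonzero(mask)
x0, y0 = int(xs.min()), int(ys.min())
x1, y1 = int(xs.max()), int(ys.max())
boxed = rgb.copy()
cv2.rectangle(boxed, (x0, y0), (x1, y1), (255, 0, 0), 2)
plt.figure()
plt.imshow(boxed)
plt.show()
```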