Learning Deep Features for Discriminative Localization论文笔记以及Caffe实现

来源:互联网 发布:win10软件很模糊 编辑:程序博客网 时间:2024/06/07 02:50

首先说一下,作者是通过Caffe的MATLAB版本实现,这里使用Caffe的Python版本实现。
这里先放一下效果图,可以很容易的理解文章的意思。
这里写图片描述
这里写图片描述
这里写图片描述
这篇论文的主要贡献在于,可以通过热成像图的结果完成localization,并且网络本身也可以分类,作者说虽然在classification上的精度有所下降,但是效果也还是不错,在ImageNet上的top-5 test error 是37.1%,和 AlexNet的 34.2%比较接近。通过热成像图也可以帮助我们更好的了解网络训练的情况,是否出现过拟合。
作者使用了Average Pool,而没有使用常用的Max Pool , 原文的解释如下,they apply global max pooling to localize
a point on objects. However, their localization is limited to a point lying in the boundary of the object rather than determining the full extent of the object. We believe that while the max and average functions are rather similar, the use of average pooling encourages the network to identify the complete extent of the object。大概意思就是说,作者认为使用Average Pool可以更好的表示位置信息。
通过下图可以看出卷积神经网络在识别的过程中,先是又卷积层获取图片的位置信息并提取特征,它可以想人一样,找到图片中最关键的部位,根据这些信息进行识别。
这里写图片描述
下面这张图展示了对于不同的分类结果,卷积层的输出。
这里写图片描述

论文里面写了算法的实现,写的比较复杂,其实实现起来还是比较简单。
这里写图片描述
基本上看这张图就可以了,作者将VGG、GoogleNet、AlexNet的最后几层去掉,一般是去掉的全连接层,具体哪一层可以在论文里看,然后作者加入3*3,stride为1,pad为1的卷积层,之后是14*14的Average Pool层,然后就是全连接层,对这个新的网络进行fine-tuned。等到收敛之后,通过每一层卷积层的输出乘以这一层对应分类的权重,然后对结果加权,就可以得到热成像图,叠加在原图上就可以有以上的效果。

代码实现如下:
可以通过python xxx.py 243这样来运行,243是imagenet中bull mastiff 的索引值。
效果图就是文章开头用到的三张图。
代码以及配置文件等之后会上传到github。

#coding:utf-8import caffeimport numpy as npimport cv2import matplotlib.pylab as pltimport sysif __name__ == '__main__':    CLASS = int(sys.argv[1])    net = caffe.Net("./deploy_googlenetCAM.prototxt",                    "./imagenet_googleletCAM_train_iter_120000.caffemodel",                    caffe.TEST)    net.forward()    weights = net.params['CAM_fc'][0].data    img = net.blobs['CAM_conv'].data[0]    weights = np.array(weights,dtype=np.float)    img = np.array(img,dtype=np.float)    new_img = np.zeros([14,14],dtype = np.float)    for i in range(1024):        w = weights[CLASS][i]        tmp =  w*img[i]        new_img += tmp    new_img = cv2.resize(new_img,(224,224))    src = cv2.imread('./cat_dog.jpg')    src = cv2.cvtColor(src,cv2.COLOR_BGR2RGB)    src = cv2.resize(src,(224,224))    new_img = 100*new_img    plt.imshow(src)    plt.imshow(new_img,alpha=0.5, interpolation='nearest')    plt.show()
0 0
原创粉丝点击