caffe的python接口学习(8):caffemodel中的参数及特征的抽取

来源:互联网 发布:企业qq与qq的端口号 编辑:程序博客网 时间:2024/05/16 01:05

原文链接:http://www.cnblogs.com/denny402/p/5686257.html

如果用公式  y=f(wx+b)

来表示整个运算过程的话,那么w和b就是我们需要训练的东西,w称为权值,在cnn中也可以叫做卷积核(filter),b是偏置项。f是激活函数,有sigmoid、relu等。x就是输入的数据。

数据训练完成后,保存的caffemodel里面,实际上就是各层的w和b值。

我们运行代码:

deploy=root + 'mnist/deploy.prototxt'    #deploy文件caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #训练好的 caffemodelnet = caffe.Net(net_file,caffe_model,caffe.TEST)   #加载model和network

就把所有的参数和数据都加载到一个net变量里面了,但是net是一个很复杂的object,想直接显示出来看是不行的。其中:

net.params: 保存各层的参数值(w和b)

net.blobs: 保存各层的数据值

可用命令:

[(k,v[0].data) for k,v in net.params.items()]

查看各层的参数值,其中k表示层的名称,v[0].data就是各层的W值,而v[1].data是各层的b值。注意:并不是所有的层都有参数,只有卷积层和全连接层才有。

也可以不查看具体值,只想看一下shape,可用命令

[(k,v[0].data.shape) for k,v in net.params.items()]

假设我们知道其中第一个卷积层的名字叫'Convolution1', 则我们可以提取这个层的参数:

w1=net.params['Convolution1'][0].datab1=net.params['Convolution1'][1].data

输入这些代码,实际查看一下,对你理解network非常有帮助。

同理,除了查看参数,我们还可以查看数据,但是要注意的是,net里面刚开始是没有数据的,需要运行:

net.forward()

之后才会有数据。我们可以用代码:

[(k,v.data.shape) for k,v in net.blobs.items()]

[(k,v.data) for k,v in net.blobs.items()]

来查看各层的数据。注意和上面查看参数的区别,一个是net.params, 一个是net.blobs.

实际上数据刚输入的时候,我们叫图片数据,卷积之后我们就叫特征了。

如果要抽取第一个全连接层的特征,则可用命令:

fea=net.blobs['InnerProduct1'].data

只要知道某个层的名称,就可以抽取这个层的特征。

推荐大家在spyder中,运行一下上面的所有代码,深入理解模型各层。

最后,总结一个代码:

复制代码
import caffeimport numpy as nproot='/home/xxx/'   #根目录deploy=root + 'mnist/deploy.prototxt'    #deploy文件caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #训练好的 caffemodelnet = caffe.Net(deploy,caffe_model,caffe.TEST)   #加载model和network
[(k,v[0].data.shape) for k,v in net.params.items()]  #查看各层参数规模
w1=net.params['Convolution1'][0].data  #提取参数wb1=net.params['Convolution1'][1].data  #提取参数b
net.forward()   #运行测试
[(k,v.data.shape) for k,v in net.blobs.items()]  #查看各层数据规模
fea=net.blobs['InnerProduct1'].data   #提取某层数据(特征)
复制代码


补充自:http://blog.csdn.net/guoyilin/article/details/42886365

本文主要对http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/filter_visualization.ipynb进行代码解析。
1. net.blobs.items() 存储了预测图片的网络中各层的feature map的数据。
2. net.params.items()存储了训练结束后学习好的网络参数。
3. vis_square 函数视觉化data,主要是进行数据归一化,data转换为plt可视化的square结构。
4. 
7.

plt.imshow(net.deprocess('data', net.blobs['data'].data[4]))

这里的4是第4个crop,图片会被crop成10个227*227.

5. net.params['conv1'][0].data, 这是表示conv1层的w参数

  net.params['conv2'][1].data, 这是表示conv1层的b参数

6.

filters.transpose(0, 2, 3, 1)对filters 4维数组进行位置对换,主要是为了将rgb放在最后一维。

net.blobs['conv1'].data[4, :36] 表示conv1层学习的feature map, 显示第4个crop image的top 36个feature map。

8. 

filters = net.params['conv2'][0].data

filters[:48].reshape(48**2, 5, 5) 对conv2 层参数w进行显示, conv2 :256 * 48 * 5 * 5, 这里显示头48个filters, reshape是为了在显示的时候把48个5*5的kernel放在一行显示,共48*48的方格显示。


0 0
原创粉丝点击