caffe的python接口学习:用训练好的模型(caffemodel或者h5)来分类新的图片

来源:互联网 发布:淘宝联盟能省钱 编辑:程序博客网 时间:2024/04/29 12:19

使用Python接口调用训练好的模型进行图像分类,需要准备以下文件:

(1)网络模型结构文件——deploy文件,该文件的生成可参考博文http://blog.csdn.net/u010417185/article/details/52137825

(2)已经训练好的模型——caffemodel或者h5都可以

(3)类别标签文件——label.txt,里面写有全部标签名称,格式如下所示:

                   

  本文以CIFAR10模型为例,进行介绍。具体代码如下:

#coding=utf-8import caffeimport numpy as nproot=root='/home/dltest/caffe/'   #根目录deploy=root + 'examples/cifar10/cifar10_quick.prototxt'    #deploy文件caffe_model=root + 'examples/cifar10/cifar10_quick_iter_4000.caffemodel.h5'  #训练好的 caffemodelimg=root+'examples/sgg_datas/images/1.jpg'   #随机找的一张待测图片labels_filename = root +'data/cifar10/batches.meta.txt'    #类别名称文件,将数字标签转换回类别名称net = caffe.Net(deploy,caffe_model,caffe.TEST)   #加载model和network#图片预处理设置transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  #设定图片的shape格式(1,3,28,28)transformer.set_transpose('data', (2,0,1))    #改变维度的顺序,由原始图片(28,28,3)变为(3,28,28)#transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))    #减去均值,前面训练模型时没有减均值,这儿就不用transformer.set_raw_scale('data', 255)    # 缩放到【0,255】之间transformer.set_channel_swap('data', (2,1,0))   #交换通道,将图片由RGB变为BGRim=caffe.io.load_image(img)                   #加载图片net.blobs['data'].data[...] = transformer.preprocess('data',im)      #执行上面设置的图片预处理操作,并将图片载入到blob中#执行测试out = net.forward()labels = np.loadtxt(labels_filename, str, delimiter='\t')   #读取类别名称文件prob= net.blobs['prob'].data[0].flatten() #取出最后一层(prob)属于某个类别的概率值,并打印print proborder=prob.argsort()[9]  #将概率值排序,取出最大值所在的序号 #argsort()函数是从小到大排列print 'the class is:',labels[order]   #将该序号转换成对应的类别名称,并打印

注:上述程序中有两处需要格外注意

(1)prob= net.blobs['prob'].data[0].flatten() #取出最后一层(prob)属于某个类别的概率值,并打印

上述代码表示的意思取出最后一层(prob)属于某个类别的概率值,所以要看一下deploy文件中最后一层的名称是什么,根据deploy文件中的最后一层的名称来确定上述语句中红色部分的 “prob” 该怎样填写。等号前面的 prob 仅仅表示一个变量名称而已,不用做更改。

我的文件中最后一层的名称是“prob”。




(2)

order=prob.argsort()[9]   #将概率值排序,取出最大值所在的序号


上述代码中的数字 9 根据分类情况不同而发生变化。prob 是一个一维数组,里面存储的是图像im属于每一类的概率值,prob.argsort() 是将数组按照从小到大排列。所以数组 prob 的个数为分类数,由于数组序号从 0 开始,所以分成10类,则其最后一个的序号为9。若分成5类,则prob数组最后一个所在位置应该为[4]。

详细解释:

argsort() 函数的作用是对数组按照从小到大的顺序排列,prob 是一个一维数组,里面存储的是图像im属于每一类的概率值,所以 prob.argsort() 所得的是按照从小到大排列的概率值。所属类别是通过概率值决定,概率越大,属于该类别的可能就越大。所以应该选择其最大值,在CIFAR10中是分为10个类别的,Python数组的序号与C相同都是从 0 开始,所以本例中数组的最后一个的序号为 9,所以图像最大概率值所属的类别为 prob.argsort()[9]









1 0
原创粉丝点击