pycaffe的使用

来源:互联网 发布:中科院在职研究生知乎 编辑:程序博客网 时间:2024/04/26 13:59

caffe的官方完美的支持python语言的兼容,提供了pycaffe的接口。用起来很方便,首先来看一下最常用到的:caffe的一个程序跑完之后会在snapshot所指定的目录下产生一个后缀名为caffemodel的文件,这里存放的就是我们在训练网络的时候得到的每层的参数信息,具体访问由net.params['layerName'][0].data访问权重参数(num_filter,channel,weight,high),net.params['layerName'][1].data访问biase,格式是(biase,)。如下图所示:这里的net.params使用的是字典格式


     当然还有保存网络结构的字典类型net.blobs['layerName'].data。这里最常用的也就是net.blobs['data']相关的使用,例如得到输入图片的大小net.blobs['data'].data.shape。改变输入图片的大小net.blobs['data'].reshape(0,3,227,227),把图片fed into网络。net.blob['data'].data[...]=inputImage,注意,这里最后一个data是一个数组,要是只有一张图片就这样net.blob['data'].data[0]=inputImage。如下图所示:


    下面用python实现一个使用自己的图片的例子:

[python] view plain copy
  1. import numpy as np  
  2. import sys,os  
  3. # 设置当前的工作环境在caffe下  
  4. caffe_root = '/home/xxx/caffe/'   
  5. # 我们也把caffe/python也添加到当前环境  
  6. sys.path.insert(0, caffe_root + 'python')  
  7. import caffe  
  8. os.chdir(caffe_root)#更换工作目录  
  9.   
  10. # 设置网络结构  
  11. net_file=caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt'  
  12. # 添加训练之后的参数  
  13. caffe_model=caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'  
  14. # 均值文件  
  15. mean_file=caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'  
  16.   
  17. # 这里对任何一个程序都是通用的,就是处理图片  
  18. # 把上面添加的两个变量都作为参数构造一个Net  
  19. net = caffe.Net(net_file,caffe_model,caffe.TEST)  
  20. # 得到data的形状,这里的图片是默认matplotlib底层加载的  
  21. transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})  
  22. # matplotlib加载的image是像素[0-1],图片的数据格式[weight,high,channels],RGB  
  23. # caffe加载的图片需要的是[0-255]像素,数据格式[channels,weight,high],BGR,那么就需要转换  
  24.   
  25. # channel 放到前面  
  26. transformer.set_transpose('data', (2,0,1))  
  27. transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))  
  28. # 图片像素放大到[0-255]  
  29. transformer.set_raw_scale('data'255)   
  30. # RGB-->BGR 转换  
  31. transformer.set_channel_swap('data', (2,1,0))  
  32.   
  33. # 这里才是加载图片  
  34. im=caffe.io.load_image(caffe_root+'examples/images/cat.jpg')  
  35. # 用上面的transformer.preprocess来处理刚刚加载图片  
  36. net.blobs['data'].data[...] = transformer.preprocess('data',im)  
  37. #注意,网络开始向前传播啦  
  38. out = net.forward()  
  39. # 最终的结果: 当前这个图片的属于哪个物体的概率(列表表示)  
  40. output_prob = output['prob'][0]  
  41. # 找出最大的那个概率  
  42. print 'predicted class is:', output_prob.argmax()  
  43.   
  44. # 也可以找出前五名的概率  
  45. top_inds = output_prob.argsort()[::-1][:5]    
  46. print 'probabilities and labels:'  
  47. zip(output_prob[top_inds], labels[top_inds])  
  48.   
  49. # 最后加载数据集进行验证  
  50. imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'  
  51. labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')  
  52.   
  53. top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]  
  54. for i in np.arange(top_k.size):  
  55.     print top_k[i], labels[top_k[i]]  

[python] view plain copy
  1. import os  
  2. import numpy as np   
  3. import os  
  4. import matplotlib.pyplot as plt   
  5. import matplotlib.patches as  mpatches  
  6. %matplotlib inline  
  7.   
  8. # 设置默认的属性:用于在ipython中显示图片  
  9. plt.rcParams['figure.figsize'] = (1010)          
  10. plt.rcParams['image.interpolation'] = 'nearest'    
  11. plt.rcParams['image.cmap'] = 'gray'    
  12. from math import pow  
  13. from skimage import transform as tf   
  14.   
  15. caffe_root='/opt/modules/caffe-master/'  
  16. sys.insert.path(0,caffe_root+'python')  
  17.   
  18. caffe_modelcaffe=caffe_root+''  
  19. caffe_deploy=caffe_root+''  
  20.   
  21. caffe.set_mode_cpu()  
  22. net=caffe.Net(caffe_deploy,caffe_modelcaffe,caffe.TEST)  
  23.   
  24.   
  25. transform=caffe.io.Transformer({'data':net.blobs['data'].data.shape})  
  26. transform.set_transpose('data',(2,0,1))  
  27. transform.set_raw_scale('data',255)  
  28. transform.set_channel_swap('data',(2,1,0))  
  29.   
  30. #把加载到的图片缩放到固定的大小  
  31. net.blobs['data'].reshape(1,2,227,227)  
  32.   
  33. image=caffe.io.load_image('/opt/data/person/1.jpg')  
  34. transformed_image=transform.preprocess('data',image)  
  35. plt.inshow(image)  
  36.   
  37. # 把警告过transform.preprocess处理过的图片加载到内存  
  38. net.blobs['data'].data[...]=transformed_image  
  39.   
  40. output=net.forward()  
  41.   
  42. #因为这里仅仅测试了一张图片  
  43. #output_pro的shape中有对于1000个object相似的概率  
  44. output_pro=output['prob'][0]  
  45.   
  46. #从候选的区域中找出最有可能的那个object的索引  
  47. output_pro_max_index=output_pro.argmax()  
  48.   
  49. labels_file = caffe_root + '.../synset_words.txt'  
  50. if not os.path.exists(labels_file):  
  51.     print "in the direct without this synset_words.txt "  
  52.     return   
  53. labels=np.loadtxt(labels_file,str,delimiter='\t')  
  54.   
  55. # 从对应的索引文件中找到最终的预测结果  
  56. outpur_label=labels[output_pro_max_index]  
  57. # 也可以找到排名前五的预测结果  
  58. top_five_index=output_pro.argsort()[::-1][:5]  
  59. print 'probabilities and labels:'  
  60. zip(output_pro[top_five_index],labels[top_five_index])  
原创粉丝点击