Python调用已训练好的caffe模型进行分类

来源:互联网 发布:淘宝认证考试 编辑:程序博客网 时间:2024/04/30 23:51

Python 作为强大的解释型语言,其丰富的库函数能够方便、快速地实现常用功能。本文介绍如何通过 Caffe 的 Python 接口(pycaffe)调用已训练好的 Caffe 模型,对单张图片进行分类。


1.Setup

# Set up the Python environment: numpy for numerical routines and
# matplotlib for plotting.
import numpy as np
import matplotlib.pyplot as plt

# Display plots inline when running inside a Jupyter/IPython notebook.
# The `%matplotlib inline` magic is IPython-only syntax (a SyntaxError in a
# plain .py script), so invoke the same magic through the IPython API when
# it is available.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')  # noqa: F821
except NameError:
    # Plain Python interpreter: figures are shown by plt.show() using
    # matplotlib's default backend.
    pass

# Set display defaults.
plt.rcParams['figure.figsize'] = (10, 10)        # large images
plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
plt.rcParams['image.cmap'] = 'gray'              # grayscale output rather than a (potentially misleading) color heatmap

2.Load caffe

# The caffe module must be on the Python path, so add it explicitly.
# caffe_root is an absolute path, so the current working directory does not
# matter here. Set the CAFFE_ROOT environment variable to point at a
# different Caffe checkout; otherwise the original default is used.
import os
import sys

caffe_root = os.environ.get('CAFFE_ROOT', '/home/ubuntu/caffe/')
sys.path.insert(0, os.path.join(caffe_root, 'python'))

# If this fails with "No module named _caffe", either pycaffe has not been
# built (`make pycaffe`) or caffe_root points to the wrong directory.
import caffe


3. Load the trained network and mean file

# Run the network on the GPU (use caffe.set_mode_cpu() on a machine without CUDA).
caffe.set_mode_gpu()

model_def = './deploy.prototxt'                       # network definition used for inference
model_pretrained = './snapshot_iter_6720.caffemodel'  # trained weights

# Load the mean image used for mean subtraction. It is a serialized Caffe
# BlobProto (binary protobuf), not a NumPy .npy file.
MEAN_PROTO_PATH = './mean.binaryproto'
blob = caffe.proto.caffe_pb2.BlobProto()
with open(MEAN_PROTO_PATH, 'rb') as f:  # `with` guarantees the file is closed
    blob.ParseFromString(f.read())

# blobproto_to_array already returns a NumPy array of shape
# (mean_number, channel, height, width).
array = caffe.io.blobproto_to_array(blob)
mu = array[0]                  # (channel, height, width)
mean = mu.mean(1).mean(1)      # average over pixels -> per-channel (BGR) mean values

net = caffe.Classifier(
    model_def, model_pretrained,
    mean=mean,
    channel_swap=(2, 1, 0),    # caffe.io.load_image returns RGB; reorder channels to BGR
    raw_scale=255,             # load_image returns values in [0, 1]; rescale them to [0, 255]
    image_dims=(256, 256),     # inputs are resized to 256x256 before being cropped to the net's input size
)

4.Classifier

# Class names indexed by label id. NOTE(review): the order must match the
# label ids used at training time (e.g. the dataset's labels file) — confirm.
label_list = ['BAC', 'Caocx2', 'MUC', 'NEG', 'RBC', 'SPURM', 'WBC', 'XIAOYUAN', 'YEA', 'YISHUI']

input_image = caffe.io.load_image('1.jpg')   # read the image to classify
prediction = net.predict([input_image])      # one row of class scores per input image
scores = prediction[0]

# Fail loudly if the network's output size does not match label_list,
# instead of raising an opaque IndexError or silently mislabeling.
if len(scores) != len(label_list):
    raise ValueError(
        'network outputs %d classes but label_list has %d entries'
        % (len(scores), len(label_list)))

predicted_label = label_list[scores.argmax()]

# print(...) with a single argument works on both Python 2 and Python 3.
print(predicted_label)

# Show the input image with the predicted class as its title.
plt.imshow(input_image)
plt.title(predicted_label)
plt.show()




结果如下:

BAC


0 0