caffe入门4:使用训练好的模型对数据分类

来源:互联网 发布:java urlencode解码 编辑:程序博客网 时间:2024/06/18 15:40

按照之前的设置,我们已经训练好了的模型文件位于$caffe/models/mydata 。
名为caffenet_train_iter_*.caffemodel
我们通过下面的python脚本来利用模型测试其在测试集外的数据的准确率。代码有点丑请见谅。

import numpy as npimport sys,oscaffe_root = '/home/will/deepLearning/caffe-ssd/' sys.path.insert(0, caffe_root + 'python')import caffeos.chdir(caffe_root)net_file=caffe_root + 'models/yangshuang/deploy.prototxt'caffe_model=caffe_root + 'models/yangshuang/caffenet_train_iter_40000.caffemodel'mean_file=caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'net = caffe.Net(net_file,caffe_model,caffe.TEST)transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})transformer.set_transpose('data', (2,0,1))transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))transformer.set_raw_scale('data', 255) transformer.set_channel_swap('data', (2,1,0))file=open("$caffe/data/mydata/test.txt")path="$caffe/data/mydata/test/"count=0a_count=0for line in file:   pic_name=line.split(" ")[0]  label=int(line.split(" ")[1].strip("\n").strip())  im=caffe.io.load_image(path+pic_name)  net.blobs['data'].data[...] = transformer.preprocess('data',im)  out = net.forward()#magenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'#labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')  top_1 = net.blobs['prob'].data[0].flatten().argsort()[-1]  if top_1==label:    print "true"    a_count=a_count+1  else:    print pic_name  count=count+1print "count:"+str(count) print "a_count:"+str(a_count)file.close()

其中,test.txt内包含图片位置信息以及其对应的标签。我们通过比较预测值和真实值来判断我们的分类是否准确。
输出的count为所有测试图片的数目。
a_count为识别准确的数目。
我使用迭代30000次的模型,测试集的准确率为99.1%,这次较小的数据测试中,准确率为98.319%,偏差不大。

zuoqianluntai/54f919d0d53b6873502.jpg 6zuoqianluntai/54f91ae477a49930823.jpg 6zuoqianluntai/54f91c3818d23717626.jpg 6zuoqianluntai/54f919d2b39e8606663.jpg 6zuoqianluntai/54f909781e5be111303.jpg 6zuoqianluntai/54f90c962b62c641014.jpg 6zuoqianluntai/54f915f3436c6806515.jpg 6zuoqianluntai/54f919752883b784886.jpg 6zuoqianluntai/54f91d10f169c520858.jpg 6

tips:1.测试数据也需要进行256*256的缩放。
2.路径一定要根据自己的文件目录写,写成绝对路径是最靠谱的。

0 0
原创粉丝点击