Testing MNIST in Caffe: two methods


1. Run %caffe_root%\examples\mnist\test_lenet.sh

caffe.exe test -model=examples/mnist/lenet_train_test.prototxt -weights=examples/mnist/lenet_iter_10000.caffemodel

Here caffe.exe is located in %caffe_root%\build\tools\Release. Run the command from %caffe_root%, since the paths in it are relative.
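One factor that may be worth checking when comparing the numbers: caffe test only forwards as many batches as its -iterations flag asks for (50 by default in Caffe's tools/caffe.cpp), and each pass consumes one TEST-phase batch from lenet_train_test.prototxt, whose stock batch_size is 100. The sketch below just does that arithmetic. The helper names are hypothetical, and the default iteration count and batch size are assumptions to verify against your own Caffe build and prototxt.

```python
MNIST_TEST_IMAGES = 10000  # images in the t10k MNIST test set

def images_evaluated(iterations, batch_size):
    """Test images forwarded by `caffe test` (hypothetical helper)."""
    return iterations * batch_size

def iterations_for_full_pass(num_images, batch_size):
    """Smallest -iterations value that covers num_images (ceiling division)."""
    return (num_images + batch_size - 1) // batch_size

# Assumed values: default -iterations of 50, TEST-phase batch_size of 100
DEFAULT_ITERATIONS = 50
TEST_BATCH_SIZE = 100

print(images_evaluated(DEFAULT_ITERATIONS, TEST_BATCH_SIZE))          # 5000
print(iterations_for_full_pass(MNIST_TEST_IMAGES, TEST_BATCH_SIZE))  # 100
```

Under those assumptions, the command above scores only 5,000 of the 10,000 test images, whereas the test run during training (test_iter: 100 in the stock lenet_solver.prototxt) covers all of them. Appending -iterations=100 to the command would make method 1 do the same.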

2. Write a Python script and make the predictions with the predict function

```python
import sys
import struct
import time

caffe_root = r"D:\caffe"
# caffe's Python bindings must be on sys.path before "import caffe"
sys.path.insert(0, caffe_root + r'\python')
import caffe

test_MNIST_label = caffe_root + r"\data\mnist\t10k-labels-idx1-ubyte"

# prepare labels: IDX1 header (magic number, item count), then one uint8 per label
index = 0
print("TEST LABEL DATA:" + test_MNIST_label)
with open(test_MNIST_label, 'rb') as binfile:
    buf = binfile.read()
magic, numImages = struct.unpack_from('>II', buf, index)
print(magic, numImages)
index += struct.calcsize('>II')
label = []
for i in range(numImages):
    tmp, = struct.unpack_from('>B', buf, index)
    index += struct.calcsize('>B')
    label.append(tmp)

# prepare image data: test_0.bmp ... test_9999.bmp, extracted beforehand
TEST_DATA_DIR = caffe_root + r"\examples\mnist\myTest\data\mnist_test"
input_image = []
for i in range(numImages):
    TEST_DATA_FILE = TEST_DATA_DIR + r"\test_%s.bmp" % i
    # color=False loads a single-channel float image with values in [0, 1]
    input_image.append(caffe.io.load_image(TEST_DATA_FILE, color=False))
    if 0 == i % 1000:
        print(len(input_image), TEST_DATA_FILE)

# prediction, timed with wall-clock time (time.time() works on Python 2 and 3)
start = time.time()
MODEL_FILE = caffe_root + r"\examples\mnist\lenet.prototxt"
PRETRAINED = caffe_root + r"\examples\mnist\lenet_iter_10000.caffemodel"
net = caffe.Classifier(MODEL_FILE, PRETRAINED, image_dims=[28, 28])
predi = []
prediction = net.predict(input_image, oversample=False)
end = time.time()
print("Done in time:%f(s)" % (end - start))
for i in range(numImages):
    #predi.append(prediction[i].argmax())
    # class indices sorted by ascending probability; the last one is the top-1 class
    predi.append(prediction[i].flatten().argsort())
    if 0 == i % 1000:
        print(i, prediction[i], predi[i])

# get accuracy (top-1: is the label the highest-scoring class?)
count = 0
wrong = []
for i in range(numImages):
    if label[i] in predi[i][-1:]:
        count += 1
    else:
        wrong.append(i)
print("accuracy:%.2f%% [count = %d]" % (float(count) / numImages * 100, count))
print("wrong:", wrong[:20])
```
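The script above goes through 10,000 intermediate BMP files produced by a conversion step that is not shown. As an alternative sketch (not the author's method), both the labels and the images can be read straight from the IDX files, using the same big-endian header layout the script already unpacks: magic number 2049 for labels and 2051 for images, then the item count (plus rows and cols for images), then raw uint8 data. Pixels are divided by 255 to match the [0, 1] range that caffe.io.load_image returns, and the array is shaped N×rows×cols×1 like a list of grayscale load_image results. The function names are hypothetical and NumPy is assumed to be available (Caffe's Python interface depends on it). The block also includes a vectorized top-1 accuracy that matches the argsort()[-1:] check in the script, up to ties.

```python
import struct
import numpy as np

def read_idx_labels(buf):
    """Parse an IDX1 label buffer (e.g. t10k-labels-idx1-ubyte) into uint8 labels."""
    magic, num = struct.unpack_from('>II', buf, 0)
    assert magic == 2049, "not an IDX1 label file"
    offset = struct.calcsize('>II')
    return np.frombuffer(buf, dtype=np.uint8, count=num, offset=offset)

def read_idx_images(buf):
    """Parse an IDX3 image buffer into float32 images in [0, 1], shape (N, rows, cols, 1)."""
    magic, num, rows, cols = struct.unpack_from('>IIII', buf, 0)
    assert magic == 2051, "not an IDX3 image file"
    offset = struct.calcsize('>IIII')
    pixels = np.frombuffer(buf, dtype=np.uint8, count=num * rows * cols, offset=offset)
    # Same scaling as caffe.io.load_image (pixel / 255), single grey channel last
    return (pixels.reshape(num, rows, cols, 1) / 255.0).astype(np.float32)

def top1_accuracy(probs, labels):
    """Fraction of rows whose highest-probability class equals the label."""
    predicted = np.asarray(probs).argmax(axis=1)
    return float(np.mean(predicted == np.asarray(labels)))

# Tiny synthetic example: 2 labels and 2 images of 2x2 pixels
label_buf = struct.pack('>II', 2049, 2) + bytes(bytearray([7, 3]))
image_buf = struct.pack('>IIII', 2051, 2, 2, 2) + bytes(bytearray([0, 255, 51, 102] * 2))
labels = read_idx_labels(label_buf)
images = read_idx_images(image_buf)
print(labels.tolist(), images.shape)  # [7, 3] (2, 2, 2, 1)
```

With the real files, read each one in binary mode and pass the image array to net.predict in place of input_image. Classifier.predict iterates over its input, so an N×28×28×1 array should behave like the list of load_image results, though this substitution is an assumption rather than something the original script does.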

* Test results

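Method 1 gives accuracy = 98.68% and method 2 gives accuracy = 99.09%. The two methods disagree, and both also differ from the accuracy of 99.08% reported by the test at the end of training. Could there be a conversion error when the data is fed in? Or does the train_test network differ from the deploy network? Or could using different patches affect the test result? The cause remains to be investigated.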
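To put a number on the first hypothesis (a conversion difference at input time): the stock lenet_train_test.prototxt scales raw pixels with scale: 0.00390625 (that is, 1/256), while caffe.io.load_image returns pixel/255, and the script in method 2 passes no raw_scale or other preprocessing option to caffe.Classifier. Assuming both of those hold for your files, the sketch below compares the two scalings over all 256 possible 8-bit pixel values. It only measures the input gap; it says nothing by itself about how much that gap moves accuracy.

```python
def scale_train(pixel):
    """Scaling done by the data layer in lenet_train_test.prototxt (assumed scale: 0.00390625)."""
    return pixel * 0.00390625  # = pixel / 256

def scale_load_image(pixel):
    """Scaling produced by caffe.io.load_image for 8-bit images (pixel / 255)."""
    return pixel / 255.0

# Gap between the deploy-time input (method 2) and the training-time input
gaps = [scale_load_image(p) - scale_train(p) for p in range(256)]
print(max(gaps))  # 0.00390625, i.e. 1/256, reached at pixel value 255
```

Under these assumptions, every method-2 input is 256/255 (about 1.004) times the value the network saw during training. If that gap matters, passing raw_scale=255.0/256 to caffe.Classifier should map the [0, 1] inputs back onto the training scale, because the Classifier's transformer multiplies inputs by raw_scale during preprocessing; this is a suggestion to test, not a confirmed fix.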
