MatConvNet框架下mnist数据集测试

来源:互联网 发布:fittime软件好用吗. 编辑:程序博客网 时间:2024/06/07 06:59

当cnn_mnist.m运行完成后,我们再打开data文件夹里的mnist-baseline-simplenn文件夹,就会发现里面多了一个pdf文件和20个net-epoch-(1~20).mat,这20个net-epoch-(1~20).mat,就是经过每一轮训练后,获得的训练好的模型。
如果在训练的时候选择了opts.batchNormalization为true的话,即进行批量归一化,那么生成的文件夹便是mnist-baseline-simplenn-bnorm,文件夹下也会有20个模型。在测试的时候,如果使用此模型,并且对图像仅仅是进行了归一化和减去均值操作,那么测试便得不到想要的结果。
在此按照ImageNet测试的demo写了一个mnist测试的代码,有关注意事项在代码中说明

run ../matlab/vl_setupnnload('../data\mnist-baseline-simplenn/net-epoch-20.mat');%此模型包含三个部分,其中一部分为netload('../data\mnist-baseline-simplenn-bnorm/imdb.mat');%images结构体在此读取net = vl_simplenn_tidy(net);net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmaxtest_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)% 挑选出测试集以及真实类别test_data = images.data(:,:,:,test_index);test_label = images.labels(test_index);im_ = test_data(:,:,:,536);%随意选取一张图像% im=imread('5.jpg');% im_=single(im);im_=imresize(im_,net.meta.inputSize(1:2));%此处和ImageNet网络名称不同im_ = im_ - images.data_mean;去均值% im_=im_-net.meta.normalization.averageImage;res=vl_simplenn(net,im_);y=res(end).x;x=gather(res(end).x);scores=squeeze(gather(res(end).x));[bestScore,best]=max(scores);figure(1);clf;imshow(im_);title(sprintf('%s %d,%.3f',...        net.meta.classes.name{best-1},best-1,bestScore));

另外还有一个对序列号为60000-70000图像进行整体精度预测的代码,大致思路与上面相同

run ../matlab/vl_setupnnload('../data\mnist-baseline-simplenn/net-epoch-11.mat');%此处换成自己下载模型存储的位置load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');net = vl_simplenn_tidy(net);net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmax% 挑选出测试样本在全体数据中对应的编号60001-70000test_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)% 挑选出测试集以及真实类别test_data = images.data(:,:,:,test_index);test_label = images.labels(test_index);% 将最后一层改为 softmax (原始为softmaxloss,这是训练用)net.layers{1, end}.type = 'softmax';% 对每张测试图片进行分类for i = 1:length(test_label)    i    im_ = test_data(:,:,:,i);    im_ = im_ - images.data_mean;    res = vl_simplenn(net, im_) ;    scores = squeeze(gather(res(end).x)) ;    [bestScore, best] = max(scores) ;    pre(i) = best;end% 计算准确率accurcy = length(find(pre==test_label))/length(test_label);disp(['accurcy = ',num2str(accurcy*100),'%']);
阅读全文
0 0
原创粉丝点击