深度学习(十三):Matconvnet详解与实验手写体数据库

来源:互联网 发布:淘宝开店交易手续费 编辑:程序博客网 时间:2024/05/21 17:18

手写体数据库是一个简单通用的模型,这是进一步理解像imagenet的cnn模型的基础模型。

关于手写体,输入的大小就是28*28的黑白二像素图像,比较简单,ok现在开始操作。

Matconvnet自带集成mnist这个实例库,数据集的下载都帮集成了,我们只需要去确认工具箱可以用,然后直接运行就可以了。

打开安装文件夹以后打开如下函数:

这里写图片描述

首先将整个安装包添加路径,然后直接运行就可以了。等一段时间后(训练20代,大概10多分钟),训练完成了,结果都存在上面那个文件夹的data里面,像下面这样:
这里写图片描述

训练的效果如图:
这里写图片描述

这样也只能看出结果收敛的好,但是具体准确率多少呢?后面介绍。

好了下面再开看看这个网络构造,自己去cnn_mnist_init这个函数就可以发现模型结构了。
贴一个解释:

MatConvNet中mnist源码解析

可以自己去看代码,就不解释了。这里假设代码看懂了,那么我抽象出来这个模型结构就像下面这样:
这里写图片描述

可以这个模型的结构依次为:输入-卷积1-pool1-卷积2-pool2-卷积2-relu-卷积4(也可以是全连接层)-sofrmax分类。

根据代码中的结构定义,我们也可以推算出每一层的输出大小(上述图)。

这里需要注意的是,这个网络层的设计大小一定要吻合。比如,经过一系列的卷积pool后,到全连接层时,输入一定要是1*1*X*X,且上一层的map,和下一层的卷积数map一定要一样。比如输入为28*28*1,那么第一个卷积核是5*5*1*20,这个1就是上一层的只有一个图。再往下走卷积为5*5*20*50,这里为什么是20,因为上一层的map有20个。一次类推,计算到最后正好为1*1*500,这样才可以全连接层。如果你自己设计网络,自己计算一下一定要保证后面全连接层时输入为1*1*X*X,否则会训练错误。

Ok,这里说到这,下面在更直观看下网上别人分享的一个网络(当然输入大小错了,中间大小也错了,但是结构类似)

这里写图片描述

好了,模型训练完了,那么怎么测试呢?
这个数据集本身也分了训练集与测试集,下面我们来测试一下测试集的准确率。建一个m脚本,函数如下:

% 导入全体数据load('D:\myself\matlab\matlab_documents\matconvnet_test\matconvnet\data\mnist-baseline-simplenn\imdb.mat');% 挑选出测试集test_index = find(images.set==3);% 挑选出样本以及真实类别test_data = images.data(:,:,:,test_index);test_label = images.labels(test_index);%导入模型文件load('D:\myself\matlab\matlab_documents\matconvnet_test\matconvnet\data\mnist-baseline-simplenn\net-epoch-20.mat');% 将最后一层改为 softmax (原始为softmaxloss,这是训练用)net.layers{1, end}.type = 'softmax';% net = vl_simplenn_tidy(net) ;for i = 1:length(test_label)    i    im_ = test_data(:,:,:,i);    im_ = im_ - images.data_mean;    res = vl_simplenn(net, im_) ;    scores = squeeze(gather(res(end).x)) ;    [bestScore, best] = max(scores) ;    pre(i) = best;end% 计算准确率accurcy = length(find(pre==test_label))/length(test_label);disp(['accurcy = ',num2str(accurcy*100),'%']);

解释一下,中间有一块需要改变最后一层的名字变为softmax,这样才是测试用的。

测试样本可以看到为10000,这样运行一下结果如下:

accurcy = 96.97%

可见准确率还是挺高的。在这个序列的前几篇,曾经也用另外一个深度学习工具箱做过mnist的实验(感兴趣可以去看)

深度学习系列(八):自编码网络多层特征学习

在那里用的是自编码以及pca等方法,都不能得到像现在这样的准确率。


进一步索引对网上大牛对Matconvnet中一些函数的详细解释

MatConvNet 庖丁解牛

注释详细至极,同时这个网站的众多好资源,表示感谢。


最后分享一个网上将这个代码转变为c++代码的测试方法,觉得很好。

C++使用matlab卷积神经网络库MatConvNet来进行手写数字识别

5 0
原创粉丝点击