MatConvNet 框架的mnist实例

来源:互联网 发布:华为java面试题2016 编辑:程序博客网 时间:2024/06/14 22:06

mnist  手写是被

cnn_mnist.m 主函数代码:

function [net, info] = cnn_mnist(varargin)  
% --------------------------------------------------------------   
%   主函数:cnn_mnist  
%   功能:  1.初始化CNN  
%           2.设置各项参数  
%           3.读取和保存数据集  
%           4.初始化train mnist的主函数 
%   参数:   varargin 可变参数
%   返回值: net  info
% ------------------------------------------------------------------------  
%运行vl_setupnn.m
run(fullfile(fileparts(mfilename('fullpath')),...
  '..', '..', 'matlab', 'vl_setupnn.m')) ;
% 参数配置
opts.batchNormalization = false ;                   %选择batchNormalization的真假  
opts.network = [] ;                                 %初始化一个网络  
opts.networkType = 'simplenn' ;                     %选择封装器:simplenn ,封装器有两种,分别为simplenn 和 dagnn  
[opts, varargin] = vl_argparse(opts, varargin) ;    %调用vl_argparse函数  参数值对的解析列表每初始化一次,就调用一次该函数
  
% 数据存放的路径
sfx = opts.networkType ;                                                %sfx=simplenn  
if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end                  %这里条件为假  
opts.expDir = fullfile(vl_rootnn, 'data', ['mnist-baseline-' sfx]) ;    %选择数据存放的路径:vl_rootnn表示根目录,data\mnist-baseline-simplenn  
[opts, varargin] = vl_argparse(opts, varargin) ;                        %调用vl_argparse函数  
  
% 数据的读取路径
opts.dataDir = fullfile(vl_rootnn, 'data', 'mnist') ;                   %选择数据读取的路径:data\matconvnet-1.0-beta23\data\mnist  
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');                      %选择imdb结构体的路径:data\data\mnist-baseline-simplenn\imdb  
opts.train = struct() ;                                                 %选择训练集返回为struct型  
opts = vl_argparse(opts, varargin) ;                                    %调用vl_argparse函数  
  
%选择是否使用GPU,使用opts.train.gpus = 1,不使用:opts.train.gpus = []。   
%if ~isfield(opts.train, 'gpus'), opts.train.gpus = 1; end;    
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end;


% ------------------------------------------------------------------------- 
%                                                              准备网络  
% -------------------------------------------------------------------------  
if isempty(opts.network)                                                    %如果原网络为空:  
  net = cnn_mnist_init('batchNormalization', opts.batchNormalization, ...   %则调用cnn_mnist_init网络结构  
    'networkType', opts.networkType) ;  
else                                                                        %否则:  
  net = opts.network ;                                                      %使用上面选择的数值带入现有网络  
  opts.network = [] ;  
end  
  
% -------------------------------------------------------------------------  
%                                                              准备数据  
% -------------------------------------------------------------------------  
if exist(opts.imdbPath, 'file')                         %如果mnist中存在imdb的结构体:  
  imdb = load(opts.imdbPath) ;                          %   载入imdb  
else                                                    %否则:  
  imdb = getMnistImdb(opts) ;                           %   调用getMnistImdb函数得到imdb并保存  
  mkdir(opts.expDir) ;                                    
  save(opts.imdbPath, '-struct', 'imdb') ;  
end  
  
%arrayfun函数通过应用sprintf函数得到array中从1到10的元素并且将其数字标签转化为char文字型  
net.meta.classes.name = arrayfun(@(x)sprintf('%d',x),1:10,'UniformOutput',false) ;  
  
% -------------------------------------------------------------------------  
%                                                              开始训练  
% -------------------------------------------------------------------------  
  
switch opts.networkType                                     %选择网络类型:  
  case 'simplenn', trainfn = @cnn_train ;                   %   1.simplenn  
  case 'dagnn', trainfn = @cnn_train_dag ;                  %   2.dagnn  
end  
  
[net, info] = trainfn(net, imdb, getBatch(opts), ...        %调用训练函数,开始训练:find(imdb.images.set == 3)为验证集的样本  
  'expDir', opts.expDir, ...                                % 参数的有关配置
  net.meta.trainOpts, ...  
  opts.train, ...  
  'val', find(imdb.images.set ==3)) ;  
  
  
% ------------------------------------------------------------------------  
function fn = getBatch(opts)  
% --------------------------------------------------------------  
%   函数名:getBatch  batch 批
%   功能:  1.由opts返回函数  
%           2.从imdb结构体取出数据   
% ------------------------------------------------------------------------  
switch lower(opts.networkType)                              %根据网络类型使用不同的getBatcch  
  case 'simplenn'  
    fn = @(x,y) getSimpleNNBatch(x,y) ;                     % 句柄函数(matlab基础知识)
  case 'dagnn'  
    bopts = struct('numGpus', numel(opts.train.gpus)) ;  
    fn = @(x,y) getDagNNBatch(bopts,x,y) ;  
end  
  
  
% --------------------------------------------------------------------  
function [images, labels] = getSimpleNNBatch(imdb, batch)  
% --------------------------------------------------------------  
%   函数名:getSimpleNNBatch  
%   功能:  1.由SimpleNN网络的批得到函数  
%           2.batch为样本的索引值  
% ------------------------------------------------------------------------  
images = imdb.images.data(:,:,:,batch) ;                %返回训练集  数据data,进行处理,最后会以矩阵的形式出现
labels = imdb.images.labels(1,batch) ;                  %返回集标签  
  
% --------------------------------------------------------------------  
function inputs = getDagNNBatch(opts, imdb, batch)  
% --------------------------------------------------------------  
%   函数名:getDagNNBatch  
%   功能:  类似上面的函数,这里的网络结构是DagNN  
% ------------------------------------------------------------------------  
images = imdb.images.data(:,:,:,batch) ;  
labels = imdb.images.labels(1,batch) ;  
if opts.numGpus > 0                                     %使用GPU进行并行运算  
  images = gpuArray(images) ;  
end  
inputs = {'input', images, 'label', labels} ;             
  
% --------------------------------------------------------------------  
function imdb = getMnistImdb(opts)            % 数据data进行处理
%--------------------------------------------------------------  
%   函数名:getMnistImdb  
%   功能:  1.从mnist数据集中获取data  
%           2.将得到的数据减去mean值,为了减少数据处理量 
%           3.将处理后的数据存放如imdb结构中  
% ------------------------------------------------------------------------  
% Preapre the imdb structure, returns image data with mean image subtracted  
files = {'train-images-idx3-ubyte', ...                     %载入mnist数据集  
         'train-labels-idx1-ubyte', ...  
         't10k-images-idx3-ubyte', ...  
         't10k-labels-idx1-ubyte'} ;  
  
if ~exist(opts.dataDir, 'dir')                              %如果不存在读取路径:  
  mkdir(opts.dataDir) ;                                     %   建立读取路径  
end  
  
for i=1:4                                                   %如果不存在mnist数据集则下载  
  if ~exist(fullfile(opts.dataDir, files{i}), 'file')  
    url = sprintf('http://yann.lecun.com/exdb/mnist/%s.gz',files{i}) ;  
    fprintf('downloading %s\n', url) ;  
    gunzip(url, opts.dataDir) ;  
  end  
end  
  
f=fopen(fullfile(opts.dataDir, 'train-images-idx3-ubyte'),'r') ;    %载入第一个文件,训练数据集大小为28*28,数量为6万  
x1=fread(f,inf,'uint8');                                              
fclose(f) ;   
x1=permute(reshape(x1(17:end),28,28,60e3),[2 1 3]) ;                %通过permute函数将数组的维度由原来的[1 2 3]变为[2 1 3] ...  
                                                                    %reshape将原数据从第17位开始构成28*28*60000的数组  
  
f=fopen(fullfile(opts.dataDir, 't10k-images-idx3-ubyte'),'r') ;     %载入第二个文件,测试数据集大小为28*28,数量为1万  
x2=fread(f,inf,'uint8');  
fclose(f) ;  
x2=permute(reshape(x2(17:end),28,28,10e3),[2 1 3]) ;                %同上解释  
  
f=fopen(fullfile(opts.dataDir, 'train-labels-idx1-ubyte'),'r') ;    %载入第三个文件:训练数据集的类标签  
y1=fread(f,inf,'uint8');  
fclose(f) ;  
y1=double(y1(9:end)')+1 ;                                                                                    
  
f=fopen(fullfile(opts.dataDir, 't10k-labels-idx1-ubyte'),'r') ;     %载入第四个文件:测试数据集的类标签  
y2=fread(f,inf,'uint8');  
fclose(f) ;  
y2=double(y2(9:end)')+1 ;  
  
%set = 1 对应训练;set = 3 对应的是测试  
set = [ones(1,numel(y1)) 3*ones(1,numel(y2))];              %numel返回元素的总数  
data = single(reshape(cat(3, x1, x2),28,28,1,[]));          %将x1的训练数据集和x2的测试数据集的第三个维度进行拼接组成新的数据集,并且转为single型减少内存  
dataMean = mean(data(:,:,:,set == 1), 4);                   %求出训练数据集中所有的图像的均值  
data = bsxfun(@minus, data, dataMean) ;                     %利用bsxfun函数将数据集中的每个元素逐个减去均值  
  
%将数据存入imdb结构中  
imdb.images.data = data ;                                   %data的大小为[28 28 1 70000]。 (60000+10000)  
imdb.images.data_mean = dataMean;                           %dataMean的大小为[28 28]  
imdb.images.labels = cat(2, y1, y2) ;                       %拼接训练数据集和测试数据集的标签,拼接后的大小为[1 70000]  
imdb.images.set = set ;                                     %set的大小为[1 70000],unique(set) = [1 3]  
imdb.meta.sets = {'train', 'val', 'test'} ;                 %imdb.meta.sets=1用于训练,imdb.meta.sets=2用于验证,imdb.meta.sets=3用于测试  
  
%arrayfun函数通过应用sprintf函数得到array中从0到9的元素并且将其数字标签转化为char文字型  
imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ; 


cnn_test_mnist.m 测试集代码

function [ net,info ] = cnn_mnist_test( varargin )
%CNN_MNIST_TEST 此处显示有关此函数的摘要
%    函数名:cnn_init_test
%    功能:  进行数据测试
%    返回值:  net info
%    参数:  varagrgin


% 加载vl_setupnn
run   matlab\vl_setupnn


% 导入数据model
load('E:\matconvnet-1.0-beta24\matconvnet-1.0-beta24\data\mnist-baseline-simplenn\net-epoch-20.mat');%此模型包含三个部分,其中一部分为net
% 导入数据集
load('E:\matconvnet-1.0-beta24\matconvnet-1.0-beta24\data\mnist-baseline-simplenn\imdb.mat');%images结构体在此读取


net = vl_simplenn_tidy(net);


net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmax


% 挑选出测试集
test_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)


% 挑选出测试集以及真实类别
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);


% for i = 1:length(test_label)
% %     im_ = test_data(:,:,:,6010);%随意选取一张图像
%         im_ = test_data(:,:,:,i);  %随意选取一张图像
% end




im_ = test_data(:,:,:,666);%随意选取一张图像
% im=imread('5.jpg');


% 将im_中转换为单精度类型
im_=single(im_);
% 归一化大小 将图片缩放到28 * 28 的大小
im_ = imresize(im_,net.meta.inputSize(1:2));%此处和ImageNet网络名称不同
im_ = im_ - images.data_mean; %去均值




% res包含了计算结果,以及中间层的输出,最后一层可以用来分类,归一化处理  sofamaxloss层
res=vl_simplenn(net,im_);
y=res(end).x;
x=gather(res(end).x);
% 删除单独维度
scores=squeeze(gather(res(end).x));
[bestScore,best]=max(scores);
figure(1);
clf;
imshow(im_);
title(sprintf('%s %d,%.3f',...
        net.meta.classes.name{best-1},best-1,bestScore));
end




% 一个对序列号为60000-70000图像进行整体精度预测的代码,
% function [ net,info ] = cnn_mnist_test( varargin )
% run ../matlab/vl_setupnn
% 导入数据model
%load('E:\matconvnet-1.0-beta24\matconvnet-1.0-beta24\data\mnist-baseline-simplenn\net-epoch-20.mat');%此模型包含三个部分,其中一部分为net
% 导入数据集
% load('E:\matconvnet-1.0-beta24\matconvnet-1.0-beta24\data\mnist-baseline-simplenn\imdb.mat');%images结构体在此读取

% net = vl_simplenn_tidy(net);
% net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmax
% % 挑选出测试样本在全体数据中对应的编号60001-70000
% test_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)
% % 挑选出测试集以及真实类别
% test_data = images.data(:,:,:,test_index);
% test_label = images.labels(test_index);

% % 将最后一层改为 softmax (原始为softmaxloss,这是训练用)
% net.layers{1, end}.type = 'softmax';

% % 对每张测试图片进行分类
% for i = 1:length(test_label)
%     i
%     im_ = test_data(:,:,:,i);
%     im_ = im_ - images.data_mean;
%     res = vl_simplenn(net, im_) ;
%     scores = squeeze(gather(res(end).x)) ;
%     [bestScore, best] = max(scores) ;
%     pre(i) = best;
% end

% % 计算准确率
% accurcy = length(find(pre==test_label))/length(test_label);
% disp(['accurcy = ',num2str(accurcy*100),'%']);
%end