MatConvNet中mnist源码解析

来源:互联网 发布:余弦匹配算法 编辑:程序博客网 时间:2024/06/05 15:41

  本文转自:http://blog.csdn.net/u010402786

       本文的代码来自MatConvNet
  下面是自己对代码的注释:  
cnn_mnist_init.m

function net = cnn_mnist_init(varargin)% CNN_MNIST_LENET Initialize a CNN similar for MNISTopts.useBatchNorm = true ;   #batchNorm是否使用opts.networkType = 'simplenn' ;  #网络结构使用lenet结构opts = vl_argparse(opts, varargin) ;rng('default');rng(0) ;f=1/100 ;net.layers = {} ;# 定义各层参数,type是网络的层属性,stride为步长,pad为填充# method中max为最大池化net.layers{end+1} = struct('type', 'conv', ...                           'weights', {{f*randn(5,5,1,20, 'single'), zeros(1, 20, 'single')}}, ...                           'stride', 1, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'pool', ...                           'method', 'max', ...                           'pool', [2 2], ...                           'stride', 2, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'conv', ...                           'weights', {{f*randn(5,5,20,50, 'single'),zeros(1,50,'single')}}, ...                           'stride', 1, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'pool', ...                           'method', 'max', ...                           'pool', [2 2], ...                           'stride', 2, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'conv', ...                           'weights', {{f*randn(4,4,50,500, 'single'),  zeros(1,500,'single')}}, ...                           'stride', 1, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'relu') ;net.layers{end+1} = struct('type', 'conv', ...                           'weights', {{f*randn(1,1,500,10, 'single'), zeros(1,10,'single')}}, ...                           'stride', 1, ...                           'pad', 0) ;net.layers{end+1} = struct('type', 'softmaxloss') ;# optionally switch to batch normalization# BN层一般用在卷积到池化过程中,激活函数前面,这里是在第1,4,7层插入BNif opts.useBatchNorm  net = insertBnorm(net, 1) ;  net = insertBnorm(net, 4) ;  net = insertBnorm(net, 7) ;end# Meta parametersnet.meta.inputSize = [27 27 1] ;  #输入大小[w,h,channel],这里是灰度图片,单通道为1net.meta.trainOpts.learningRate = 0.001 ; #学习率net.meta.trainOpts.numEpochs = 20 ; #epoch次数,注意这里不是所谓的迭代次数net.meta.trainOpts.batchSize = 100 ; #批处理,这里就是mini-batchsize,batchSize大小对训练过程的影响见我另外一篇博客:卷积神经网络四大问题之一# Fill in defaul valuesnet = vl_simplenn_tidy(net) ;# Switch to DagNN if requested# 选择不同的网络结构,这里就使用的simplenn结构switch lower(opts.networkType)  case 'simplenn'    % done  case 'dagnn'    net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;    net.addLayer('error', dagnn.Loss('loss', 'classerror'), ...             {'prediction','label'}, 'error') ;  otherwise    assert(false) ;end% --------------------------------------------------------------------function net = insertBnorm(net, l)   #具体的BN函数% --------------------------------------------------------------------assert(isfield(net.layers{l}, 'weights'));ndim = size(net.layers{l}.weights{1}, 4);layer = struct('type', 'bnorm', ...               'weights', {{ones(ndim, 1, 'single'), zeros(ndim, 1, 'single')}}, ...               'learningRate', [1 1 0.05], ...               'weightDecay', [0 0]) ;net.layers{l}.biases = [] ;net.layers = horzcat(net.layers(1:l), layer, net.layers(l+1:end)) ; #horzcat水平方向矩阵连接,这里就是重新构建网络结构,将BN层插入到lennt中
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81

cnn_mnist_experiments.m

%% Experiment with the cnn_mnist_fc_bnorm[net_bn, info_bn] = cnn_mnist(...  'expDir', 'data/mnist-bnorm', 'useBnorm', true);[net_fc, info_fc] = cnn_mnist(...  'expDir', 'data/mnist-baseline', 'useBnorm', false);# 以下就是画图的代码figure(1) ; clf ;subplot(1,2,1) ;  # 第一张图semilogy(info_fc.val.objective', 'o-') ; hold all ;semilogy(info_bn.val.objective', '+--') ;  #表示y坐标轴是对数坐标系xlabel('Training samples [x 10^3]'); ylabel('energy') ;grid on ; #加入网格h=legend('BSLN', 'BNORM') ;  #加入标注set(h,'color','none');title('objective') ;subplot(1,2,2) ;plot(info_fc.val.error', 'o-') ; hold all ;plot(info_bn.val.error', '+--') ;h=legend('BSLN-val','BSLN-val-5','BNORM-val','BNORM-val-5') ;grid on ;xlabel('Training samples [x 10^3]'); ylabel('error') ;set(h,'color','none') ;title('error') ;drawnow ;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

  运行结果得到的图:
    这里写图片描述

阅读全文
0 0
原创粉丝点击