MatConvNet中mnist源码解析
来源:互联网 发布:余弦匹配算法 编辑:程序博客网 时间:2024/06/05 15:41
本文转自:http://blog.csdn.net/u010402786
本文的代码来自MatConvNet
下面是自己对代码的注释:
cnn_mnist_init.m
function net = cnn_mnist_init(varargin)% CNN_MNIST_LENET Initialize a CNN similar for MNISTopts.useBatchNorm = true ; #batchNorm是否使用opts.networkType = 'simplenn' ; #网络结构使用lenet结构opts = vl_argparse(opts, varargin) ;rng('default');rng(0) ;f=1/100 ;net.layers = {} ;# 定义各层参数,type是网络的层属性,stride为步长,pad为填充# method中max为最大池化net.layers{end+1} = struct('type', 'conv', ... 'weights', {{f*randn(5,5,1,20, 'single'), zeros(1, 20, 'single')}}, ... 'stride', 1, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'pool', ... 'method', 'max', ... 'pool', [2 2], ... 'stride', 2, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'conv', ... 'weights', {{f*randn(5,5,20,50, 'single'),zeros(1,50,'single')}}, ... 'stride', 1, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'pool', ... 'method', 'max', ... 'pool', [2 2], ... 'stride', 2, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'conv', ... 'weights', {{f*randn(4,4,50,500, 'single'), zeros(1,500,'single')}}, ... 'stride', 1, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'relu') ;net.layers{end+1} = struct('type', 'conv', ... 'weights', {{f*randn(1,1,500,10, 'single'), zeros(1,10,'single')}}, ... 'stride', 1, ... 'pad', 0) ;net.layers{end+1} = struct('type', 'softmaxloss') ;# optionally switch to batch normalization# BN层一般用在卷积到池化过程中,激活函数前面,这里是在第1,4,7层插入BNif opts.useBatchNorm net = insertBnorm(net, 1) ; net = insertBnorm(net, 4) ; net = insertBnorm(net, 7) ;end# Meta parametersnet.meta.inputSize = [27 27 1] ; #输入大小[w,h,channel],这里是灰度图片,单通道为1net.meta.trainOpts.learningRate = 0.001 ; #学习率net.meta.trainOpts.numEpochs = 20 ; #epoch次数,注意这里不是所谓的迭代次数net.meta.trainOpts.batchSize = 100 ; #批处理,这里就是mini-batchsize,batchSize大小对训练过程的影响见我另外一篇博客:卷积神经网络四大问题之一# Fill in defaul valuesnet = vl_simplenn_tidy(net) ;# Switch to DagNN if requested# 选择不同的网络结构,这里就使用的simplenn结构switch lower(opts.networkType) case 'simplenn' % done case 'dagnn' net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ; net.addLayer('error', dagnn.Loss('loss', 'classerror'), ... {'prediction','label'}, 'error') ; otherwise assert(false) ;end% --------------------------------------------------------------------function net = insertBnorm(net, l) #具体的BN函数% --------------------------------------------------------------------assert(isfield(net.layers{l}, 'weights'));ndim = size(net.layers{l}.weights{1}, 4);layer = struct('type', 'bnorm', ... 'weights', {{ones(ndim, 1, 'single'), zeros(ndim, 1, 'single')}}, ... 'learningRate', [1 1 0.05], ... 'weightDecay', [0 0]) ;net.layers{l}.biases = [] ;net.layers = horzcat(net.layers(1:l), layer, net.layers(l+1:end)) ; #horzcat水平方向矩阵连接,这里就是重新构建网络结构,将BN层插入到lennt中
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
cnn_mnist_experiments.m
%% Experiment with the cnn_mnist_fc_bnorm[net_bn, info_bn] = cnn_mnist(... 'expDir', 'data/mnist-bnorm', 'useBnorm', true);[net_fc, info_fc] = cnn_mnist(... 'expDir', 'data/mnist-baseline', 'useBnorm', false);# 以下就是画图的代码figure(1) ; clf ;subplot(1,2,1) ; # 第一张图semilogy(info_fc.val.objective', 'o-') ; hold all ;semilogy(info_bn.val.objective', '+--') ; #表示y坐标轴是对数坐标系xlabel('Training samples [x 10^3]'); ylabel('energy') ;grid on ; #加入网格h=legend('BSLN', 'BNORM') ; #加入标注set(h,'color','none');title('objective') ;subplot(1,2,2) ;plot(info_fc.val.error', 'o-') ; hold all ;plot(info_bn.val.error', '+--') ;h=legend('BSLN-val','BSLN-val-5','BNORM-val','BNORM-val-5') ;grid on ;xlabel('Training samples [x 10^3]'); ylabel('error') ;set(h,'color','none') ;title('error') ;drawnow ;
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
运行结果得到的图:
阅读全文
0 0
- MatConvNet中mnist源码解析
- MatConvNet中mnist源码解析
- MatConvNet 源码解析
- MatConvNet中MNIST 数据库训练的例子
- MatConvNet框架的mnist例子
- MatConvNet 框架的mnist实例
- MatConvNet框架下mnist数据集测试
- 深度学习 9. MatConvNet 利用mnist的model来训练自己的data。MatConvNet 训练自己数据(一)。
- 解析mnist数据库
- Tensorflow之MNIST解析
- tinyxml源码解析(中)
- tinyxml源码解析(中)
- TensorFlow入门01:MNIST分类的源码及关键函数解析
- TensorFlow入门02:cnn实现MNIST分类的源码及关键函数解析
- Matconvnet学习——利用mnist网络训练自己的数据分辨左右手
- StudyAI上MatConvNet框架学习笔记之3:mnist实例代码分析
- win7下matlab 中安装 matconvnet
- matconvnet中使用fastrcnn遇到的问题
- SCI信件回复
- stack的模拟实现
- Invalid <url-pattern> /*.action in filter mapping错误以及(Servlet和Filter的url匹配url-p)
- 关于 epoch、 iteration和batchsize ,关于batchsize
- Java中HashMap的工作原理
- MatConvNet中mnist源码解析
- destoon 常量与变量
- 也来玩玩反编译
- repo问题
- 【资讯】福布斯:旅行积分计划是区块链主要目标,对旅行者来说是好消息
- 【国际】印度央行研究“法定加密货币”作为数字卢比
- 【资讯】卡巴斯基实验室:165万台用户电脑受到加密货币挖矿恶意软件攻击
- CentOS克隆机器步骤,图文教程
- 【国际】费城联邦储备银行会议探索区块链对金融稳定的影响