DeepLearnToolbox_DBN notes
Source: Internet | Posted by: 五十岁而知天命 | Editor: 程序博客网 | Time: 2024/05/29 10:53
Contents
- ex1 train a 100 hidden unit RBM and visualize its weights
- ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
function test_example_DBN
load mnist_uint8;

% Normalize the data: scale pixel values to [0, 1] and convert labels to double
train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);
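The order of operations in the normalization matters, because rbmtrain asserts that its input is floating point. A minimal sketch on a made-up 2x3 uint8 array (the values are illustrative and not taken from MNIST; `raw`, `x` and `y` are hypothetical names):

```matlab
% Hypothetical uint8 block standing in for a few pixels of train_x
raw = uint8([0 128 255; 64 32 16]);

% Convert to double first, then scale into [0, 1]
x = double(raw) / 255;
assert(isfloat(x));
assert(min(x(:)) >= 0 && max(x(:)) <= 1);

% Dividing without the conversion stays in integer arithmetic and rounds
y = raw / 255;
assert(isa(y, 'uint8'));
assert(isequal(y, uint8([0 1 1; 0 0 0])));
```

Without the double(...) call the grey levels collapse to 0 or 1 and the isfloat check in rbmtrain would fail.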
ex1 train a 100 hidden unit RBM and visualize its weights
rand('state',0)
dbn.sizes = [100];      % one hidden layer with 100 units
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum  = 0;
opts.alpha     = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);
figure; visualize(dbn.rbm{1}.W');   % Visualize the RBM weights
epoch 1/1. Average reconstruction error is: 66.2661
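The reconstruction error printed above is, per the rbmtrain code in this post, the squared difference between a batch and its reconstruction, summed over all pixels, divided by the batch size, and then averaged over batches. A small sketch of the per-batch quantity on made-up values (`v2` here is an invented reconstruction, not the output of a trained RBM):

```matlab
% Toy batch: 4 samples, 3 visible units (illustrative values only)
v1 = [1 0 1; 0 0 1; 1 1 0; 0 1 0];   % data
v2 = [1 0 0; 0 0 1; 1 0 0; 0 1 1];   % made-up reconstruction

batchsize = size(v1, 1);

% Same expression as in rbmtrain: total squared error per sample
batch_err = sum(sum((v1 - v2) .^ 2)) / batchsize;

% Three entries differ across four samples, so the value is 3/4
assert(abs(batch_err - 0.75) < 1e-12);
```

Because the sum runs over all 784 pixels of each image, a value such as 66.2661 is a per-image total rather than a per-pixel average.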
ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
rand('state',0)
%train dbn
dbn.sizes = [100 100];
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum  = 0;
opts.alpha     = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

%train nn
opts.numepochs = 1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');
epoch 1/1. Average reconstruction error is: 66.2661
epoch 1/1. Average reconstruction error is: 10.286
epoch 1/1. Took 3.4378 seconds. Mini-batch mean squared error on training set is 0.16201; Full-batch train err = 0.089004
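The unfolding step in ex2 copies each RBM into one NN layer as [c W], with the hidden bias as the first column. The sketch below checks on toy values that this layout gives the same activations as the rbmup expression, assuming the NN forward pass prepends a constant 1 to each input row (that is my understanding of the toolbox's nnff, which is not shown in this post). All variable names and numbers are illustrative.

```matlab
% Toy RBM with 3 visible and 2 hidden units (values are illustrative)
W = [0.1 0.2 0.3; 0.4 0.5 0.6];   % 2x3, hidden x visible
c = [-1; 1];                      % 2x1 hidden bias

% Layout used by dbnunfoldtonn: bias column first, then the weights
Wnn = [c W];                      % 2x4

x = [1 0 1];                      % one input row

% rbmup-style activation: bias added explicitly
a_rbm = 1 ./ (1 + exp(-(c' + x * W')));

% NN-style activation (assumed): constant 1 prepended to the input
a_nn = 1 ./ (1 + exp(-([1 x] * Wnn')));

assert(max(abs(a_rbm - a_nn)) < 1e-12);
```

Only the layers that have a matching RBM are initialized this way; the final 100-to-10 output layer keeps whatever nnsetup gave it, since the loop in dbnunfoldtonn runs over numel(dbn.rbm).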
function dbn = dbnsetup(dbn, x, opts)
    n = size(x, 2);                     % input dimension (784 for MNIST) = size of the visible layer
    dbn.sizes = [n, dbn.sizes];         % e.g. [784 100] in ex1

    for u = 1 : numel(dbn.sizes) - 1    % one RBM (one weight matrix) per pair of adjacent layers
        dbn.rbm{u}.alpha    = opts.alpha;
        dbn.rbm{u}.momentum = opts.momentum;

        dbn.rbm{u}.W  = zeros(dbn.sizes(u + 1), dbn.sizes(u));   % weights, 100x784 for the first RBM
        dbn.rbm{u}.vW = zeros(dbn.sizes(u + 1), dbn.sizes(u));

        dbn.rbm{u}.b  = zeros(dbn.sizes(u), 1);                  % visible bias, 784x1
        dbn.rbm{u}.vb = zeros(dbn.sizes(u), 1);

        dbn.rbm{u}.c  = zeros(dbn.sizes(u + 1), 1);              % hidden bias, 100x1
        dbn.rbm{u}.vc = zeros(dbn.sizes(u + 1), 1);
    end
end

function dbn = dbntrain(dbn, x, opts)
    n = numel(dbn.rbm);                 % number of RBMs

    dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);
    for i = 2 : n
        x = rbmup(dbn.rbm{i - 1}, x);   % hidden activations of the previous RBM become the next input
        dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts);
    end
end

function rbm = rbmtrain(rbm, x, opts)
    assert(isfloat(x), 'x must be a float');
    m = size(x, 1);                     % number of samples
    numbatches = m / opts.batchsize;    % number of batches

    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs          % training epochs
        kk = randperm(m);
        err = 0;
        for l = 1 : numbatches
            % draw one batch at random, without replacement
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

            v1 = batch;                 % initialize v1 with the data
            % Gibbs sampling
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
            h2 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

            c1 = h1' * v1;
            c2 = h2' * v2;

            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)       / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)'   / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)'   / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
    end
end

function x = rbmup(rbm, x)
    x = sigm(repmat(rbm.c', size(x, 1), 1) + x * rbm.W');
end

function nn = dbnunfoldtonn(dbn, outputsize)
%DBNUNFOLDTONN Unfolds a DBN to a NN
%   dbnunfoldtonn(dbn, outputsize ) returns the unfolded dbn with a final
%   layer of size outputsize added.
    if(exist('outputsize','var'))
        size = [dbn.sizes outputsize];
    else
        size = [dbn.sizes];
    end
    nn = nnsetup(size);
    for i = 1 : numel(dbn.rbm)
        nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];   % hidden bias as first column, then the RBM weights
    end
end
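The inner loop of rbmtrain is a single contrastive-divergence step (one Gibbs sweep v1 -> h1 -> v2 -> h2). The sketch below runs that step once on a toy batch so the matrix shapes can be checked without MNIST. The helpers sigm and sigmrnd are not listed in this post; the anonymous functions here are stand-ins based on what I believe the toolbox versions do (logistic function, and a Bernoulli sample from it), so treat them as assumptions. The sizes nv, nh and the random batch are illustrative.

```matlab
% Stand-ins for the toolbox helpers (assumed definitions, not copied from the post)
sigm    = @(P) 1 ./ (1 + exp(-P));
sigmrnd = @(P) double(sigm(P) > rand(size(P)));

% Toy problem: 6 visible units, 3 hidden units, batch of 4 samples
nv = 6; nh = 3; batchsize = 4; alpha = 1;
W = zeros(nh, nv);      % hidden x visible, as in dbnsetup
b = zeros(nv, 1);       % visible bias
c = zeros(nh, 1);       % hidden bias

v1 = double(rand(batchsize, nv) > 0.5);   % random binary batch

% One Gibbs sweep, same expressions as in rbmtrain
h1 = sigmrnd(repmat(c', batchsize, 1) + v1 * W');
v2 = sigmrnd(repmat(b', batchsize, 1) + h1 * W);
h2 = sigmrnd(repmat(c', batchsize, 1) + v2 * W');

% Updates with momentum = 0: data statistics minus reconstruction statistics
dW = alpha * (h1' * v1 - h2' * v2) / batchsize;
db = alpha * sum(v1 - v2)' / batchsize;
dc = alpha * sum(h1 - h2)' / batchsize;

% Each update has the same shape as the parameter it is added to
assert(isequal(size(dW), size(W)));
assert(isequal(size(db), size(b)));
assert(isequal(size(dc), size(c)));
```

With opts.momentum = 0, as in both examples, rbm.vW reduces to dW above, so each batch changes the weights by exactly this amount.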