DeepLearnToolbox_DBN notes

来源:互联网 发布:五十岁而知天命 编辑:程序博客网 时间:2024/05/29 10:53

Contents

  • ex1: train a 100-hidden-unit RBM and visualize its weights
  • ex2: train a 100-100 hidden-unit DBN and use its weights to initialize a NN
function test_example_DBN
load mnist_uint8;

% Normalize the inputs: cast the raw pixel arrays to double and scale by 1/255
% (presumably the stored data is uint8, per the mnist_uint8 file name).
train_x = double(train_x) ./ 255;
test_x  = double(test_x)  ./ 255;

% The label arrays only need the type cast.
train_y = double(train_y);
test_y  = double(test_y);

ex1: Train a 100-hidden-unit RBM and visualize its weights

rand('state', 0)

% One hidden layer of 100 units, i.e. a single RBM on top of the inputs.
dbn.sizes = 100;

% One training epoch, mini-batches of 100 samples, no momentum, learning rate 1.
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum  = 0;
opts.alpha     = 1;

dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

% Visualize the RBM weights; W is (hidden x visible), so W' has one column
% per hidden unit.
figure;
visualize(dbn.rbm{1}.W');
epoch 1/1. Average reconstruction error is: 66.2661

ex2: Train a 100-100 hidden-unit DBN and use its weights to initialize a NN

rand('state', 0)

% ---- Train the DBN: two stacked RBMs with 100 hidden units each ----
% Start from empty structs: ex1 used the same variable names, and a stale
% dbn.rbm{k} entry left over from a deeper earlier DBN would otherwise still
% be trained by dbntrain (it loops over numel(dbn.rbm)).
dbn  = struct();
opts = struct();
dbn.sizes = [100 100];
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum  = 0;
opts.alpha     = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

% ---- Unfold the DBN into a NN with an added 10-unit output layer ----
% (presumably one output per MNIST digit class)
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

% ---- Train the unfolded NN and check its test error ----
% opts still carries momentum/alpha from the DBN phase, as in the original
% example; only the epoch count and batch size are (re)set here.
opts.numepochs = 1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.10, 'Too big error');
epoch 1/1. Average reconstruction error is: 66.2661
epoch 1/1. Average reconstruction error is: 10.286
epoch 1/1. Took 3.4378 seconds. Mini-batch mean squared error on training set is 0.16201; Full-batch train err = 0.089004
function dbn = dbnsetup(dbn, x, opts)
% DBNSETUP  Allocate the RBMs of a DBN.
%   On input dbn.sizes lists the hidden-layer sizes; on output it is
%   [size(x, 2), hidden sizes], i.e. the input width becomes the first
%   (visible) layer. One RBM is created per adjacent pair of layers.
%   Weights, biases and their velocity buffers all start at zero.
%
%   dbn   struct with field sizes (row vector of hidden-layer sizes)
%   x     data matrix, one sample per row; its column count is the
%         visible-layer size
%   opts  struct providing alpha and momentum, copied into every RBM
    visibleSize = size(x, 2);
    dbn.sizes = [visibleSize, dbn.sizes];

    for k = 1 : numel(dbn.sizes) - 1
        nVis = dbn.sizes(k);        % units in the lower (visible) layer
        nHid = dbn.sizes(k + 1);    % units in the upper (hidden) layer

        dbn.rbm{k}.alpha    = opts.alpha;
        dbn.rbm{k}.momentum = opts.momentum;
        dbn.rbm{k}.W  = zeros(nHid, nVis);   % weights, hidden x visible
        dbn.rbm{k}.vW = zeros(nHid, nVis);   % weight velocity
        dbn.rbm{k}.b  = zeros(nVis, 1);      % visible-layer bias
        dbn.rbm{k}.vb = zeros(nVis, 1);      % visible-bias velocity
        dbn.rbm{k}.c  = zeros(nHid, 1);      % hidden-layer bias
        dbn.rbm{k}.vc = zeros(nHid, 1);      % hidden-bias velocity
    end
end
function dbn = dbntrain(dbn, x, opts)
% DBNTRAIN  Greedy layer-wise training of the RBMs in dbn.rbm.
%   The first RBM is trained on x. Each later RBM is trained on rbmup of
%   the (already trained) RBM directly below it, applied to that RBM's
%   own input. The first RBM is always trained, so dbn.rbm must contain
%   at least one RBM.
    numRbms = numel(dbn.rbm);
    layerInput = x;
    k = 1;
    while true
        dbn.rbm{k} = rbmtrain(dbn.rbm{k}, layerInput, opts);
        if k >= numRbms
            break;
        end
        % Propagate through the freshly trained RBM to get the next input.
        layerInput = rbmup(dbn.rbm{k}, layerInput);
        k = k + 1;
    end
end
function rbm = rbmtrain(rbm, x, opts)
% RBMTRAIN  Train an RBM with one-step contrastive divergence (CD-1) and
% momentum on shuffled mini-batches.
%   rbm   struct with W (hidden x visible), b (visible bias, column),
%         c (hidden bias, column), velocities vW/vb/vc, alpha, momentum
%   x     float data matrix, one sample per row
%   opts  struct with numepochs and batchsize; size(x, 1) must be an
%         exact multiple of batchsize
%   After every epoch the average (over mini-batches) per-sample squared
%   reconstruction error is printed.
    assert(isfloat(x), 'x must be a float');
    m = size(x, 1);                          % number of samples
    numbatches = m / opts.batchsize;         % number of mini-batches
    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs
        kk = randperm(m);                    % new shuffle every epoch
        err = 0;
        for l = 1 : numbatches
            % Take the next mini-batch from the shuffle (no repeats within an epoch).
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);
            v1 = batch;

            % One Gibbs step v1 -> h1 -> v2 -> h2. sigmrnd is defined
            % elsewhere; presumably it draws binary samples from sigm(.).
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
            h2 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

            % Positive- and negative-phase correlations (hidden x visible).
            c1 = h1' * v1;
            c2 = h2' * v2;

            % Momentum updates. sum(..., 1) explicitly sums over samples
            % (rows): with a batch of one sample the plain sum() would sum a
            % row vector along its columns, collapse to a scalar, and apply
            % the same value to every bias.
            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)        / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            % Per-sample squared reconstruction error for this batch.
            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)  '. Average reconstruction error is: ' num2str(err / numbatches)]);
    end
end
function x = rbmup(rbm, x)
% RBMUP  Propagate data upward through an RBM.
%   Returns sigm(x * W' + c') row by row, using the deterministic sigm
%   (defined elsewhere; presumably the logistic sigmoid) rather than a
%   sampled value. x has one sample per row.
    preactivation = x * rbm.W';
    x = sigm(preactivation + repmat(rbm.c', size(x, 1), 1));
end
function nn = dbnunfoldtonn(dbn, outputsize)
%DBNUNFOLDTONN Unfolds a DBN to a NN
%   dbnunfoldtonn(dbn, outputsize) returns the unfolded dbn with a final
%   layer of size outputsize added. Without outputsize, the NN layers are
%   exactly dbn.sizes.
%
%   The NN is built by nnsetup (defined elsewhere) from the layer sizes.
%   Then each of its first numel(dbn.rbm) weight matrices is overwritten
%   with [c W] of the matching RBM: the hidden bias is the first column,
%   followed by the hidden x visible weights.
    % Named layersizes (not "size") so the built-in size() is not shadowed.
    if exist('outputsize', 'var')
        layersizes = [dbn.sizes outputsize];
    else
        layersizes = dbn.sizes;
    end

    nn = nnsetup(layersizes);
    for i = 1 : numel(dbn.rbm)
        nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];
    end
end
0 0