DeepLearnToolbox_SAE notes

Source: Internet · Editor: 程序博客网 · Date: 2024/05/08 18:46


  • ex1 train a 100 hidden unit SDAE and use it to initialize a FFNN
  • ex2 train a 100-100 hidden unit SDAE and use it to initialize a FFNN
function test_example_SAE
load mnist_uint8;
%ex_choise=5;  % choose which experiment to run (1-6)
train_x = double(train_x)/255;  % 60000x784, each row is one sample
test_x  = double(test_x)/255;   % 10000 test samples
train_y = double(train_y);      % 60000x10 one-hot labels
test_y  = double(test_y);
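Dividing by 255 maps the uint8 pixel values 0..255 into [0, 1]. This matters because each autoencoder reconstructs its own input, and with nnsetup's default 'sigm' output unit the reconstruction can only take values in (0, 1). A minimal Python sketch of the same conversion (`normalize_pixels` is a hypothetical helper, not part of DeepLearnToolbox):

```python
from typing import List

def normalize_pixels(rows: List[List[int]]) -> List[List[float]]:
    """Map uint8 pixel values (0..255) to floats in [0, 1], like double(x)/255 in MATLAB."""
    return [[p / 255.0 for p in row] for row in rows]

print(normalize_pixels([[0, 51, 255]]))  # [[0.0, 0.2, 1.0]]
```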

ex1 train a 100 hidden unit SDAE and use it to initialize a FFNN

Set up and train a stacked denoising autoencoder (SDAE)
rand('state',0)
sae = saesetup([784 100]);  % build one autoencoder (784-100-784)
sae.ae{1}.activation_function       = 'sigm';
sae.ae{1}.learningRate              = 1;
sae.ae{1}.inputZeroMaskedFraction   = 0.5;  % used for denoising autoencoders
opts.numepochs =   1;
opts.batchsize = 100;
sae = saetrain(sae, train_x, opts);
visualize(sae.ae{1}.W{1}(:,2:end)')  % visualize the first-layer weights sae.ae{1}.W{1}, bias column removed

% Use the SDAE to initialize a FFNN
nn = nnsetup([784 100 10]);
nn.activation_function              = 'sigm';
nn.learningRate                     = 1;
nn.W{1} = sae.ae{1}.W{1};  % initialize nn.W{1} with the unsupervised, layer-wise pretrained W{1}

% Train the FFNN
opts.numepochs =   1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.16, 'Too big error');
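`inputZeroMaskedFraction = 0.5` is what makes this a denoising autoencoder: during training each input entry of a mini-batch is set to zero with probability 0.5, while the reconstruction target stays the clean input. Assuming nntrain applies the mask as `batch_x = batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction)`, a Python sketch of the same binary masking noise looks like this (`mask_inputs` is a hypothetical name):

```python
import random
from typing import List

def mask_inputs(batch: List[List[float]], fraction: float, rng: random.Random) -> List[List[float]]:
    """Zero each entry independently with probability `fraction` (binary masking noise).

    Mirrors rand(size(x)) > fraction: an entry survives only if its uniform draw exceeds `fraction`.
    """
    return [[v if rng.random() > fraction else 0.0 for v in row] for row in batch]

rng = random.Random(0)
clean = [[1.0] * 784 for _ in range(100)]  # one mini-batch of 100 flattened 28x28 images
noisy = mask_inputs(clean, 0.5, rng)       # the AE sees `noisy` but is trained to output `clean`
kept = sum(v != 0.0 for row in noisy for v in row) / (100 * 784)
print(kept)  # fraction of surviving entries, roughly 1 - 0.5
```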
Training AE 1/1
epoch 1/1. Took 6.1412 seconds. Mini-batch mean squared error on training set is 10.6379; Full-batch train err = 9.644005
epoch 1/1. Took 2.7784 seconds. Mini-batch mean squared error on training set is 0.21772; Full-batch train err = 0.108234
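The `er` checked by `assert(er < 0.16, ...)` is a classification error rate on the 10,000 test images. Assuming nntest works as in the DeepLearnToolbox source (predicted class = index of the largest network output, true class = index of the 1 in the one-hot `test_y` row), a Python sketch of that computation is below; `error_rate` is a hypothetical name, and `bad` here uses 0-based indices where MATLAB's would be 1-based:

```python
from typing import List, Sequence, Tuple

def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry; ties go to the first one, as with MATLAB's max."""
    return max(range(len(row)), key=lambda k: row[k])

def error_rate(outputs: List[List[float]], one_hot: List[List[float]]) -> Tuple[float, List[int]]:
    """Return (er, bad): misclassified fraction and indices of misclassified samples."""
    predicted = [argmax(r) for r in outputs]
    expected = [argmax(r) for r in one_hot]
    bad = [i for i, (p, e) in enumerate(zip(predicted, expected)) if p != e]
    return len(bad) / len(outputs), bad

outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # network outputs, one row per sample
labels = [[0, 1], [0, 1], [0, 1]]               # one-hot targets
er, bad = error_rate(outputs, labels)
print(er, bad)  # 0.3333333333333333 [1]
```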

ex2 train a 100-100 hidden unit SDAE and use it to initialize a FFNN

Set up and train a stacked denoising autoencoder (SDAE)
rand('state',0)
sae = saesetup([784 100 100]);  % two stacked AEs: 784-100-784, then 100-100-100
sae.ae{1}.activation_function       = 'sigm';
sae.ae{1}.learningRate              = 1;
sae.ae{1}.inputZeroMaskedFraction   = 0.5;  % binary masking noise
sae.ae{2}.activation_function       = 'sigm';
sae.ae{2}.learningRate              = 1;
sae.ae{2}.inputZeroMaskedFraction   = 0.5;
opts.numepochs =   1;
opts.batchsize = 100;
sae = saetrain(sae, train_x, opts);
visualize(sae.ae{1}.W{1}(:,2:end)')

% Use the SDAE to initialize a FFNN
nn = nnsetup([784 100 100 10]);
nn.activation_function              = 'sigm';
nn.learningRate                     = 1;
% add pretrained weights
nn.W{1} = sae.ae{1}.W{1};
nn.W{2} = sae.ae{2}.W{1};

% Train the FFNN
opts.numepochs =   1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.1, 'Too big error');
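The assignments `nn.W{k} = sae.ae{k}.W{1}` work because the shapes line up. Assuming nnsetup stores each weight matrix as (units in next layer) x (units in previous layer + 1), with the bias weights in the first column, the encoder of autoencoder k has the same shape as layer k of the FFNN; that bias column is also why `visualize` is called on `W{1}(:,2:end)'`. A Python sketch that checks this for the ex2 architecture (helper names are hypothetical):

```python
from typing import List, Tuple

Shape = Tuple[int, int]

def weight_shapes(layer_sizes: List[int]) -> List[Shape]:
    """Shape of each W{i} for a net built by nnsetup(layer_sizes): (n_out, n_in + 1)."""
    return [(layer_sizes[i + 1], layer_sizes[i] + 1) for i in range(len(layer_sizes) - 1)]

def encoder_shapes(sae_sizes: List[int]) -> List[Shape]:
    """Encoder W{1} of each AE nnsetup([n_k, n_{k+1}, n_k]) created by saesetup(sae_sizes)."""
    return [weight_shapes([sae_sizes[k], sae_sizes[k + 1], sae_sizes[k]])[0]
            for k in range(len(sae_sizes) - 1)]

enc = encoder_shapes([784, 100, 100])      # [(100, 785), (100, 101)]
ffnn = weight_shapes([784, 100, 100, 10])  # [(100, 785), (100, 101), (10, 101)]
# The pretrained encoders fit nn.W{1} and nn.W{2}; the output layer keeps its random init.
assert enc == ffnn[:len(enc)]
```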
Training AE 1/2
epoch 1/1. Took 6.4605 seconds. Mini-batch mean squared error on training set is 10.6551; Full-batch train err = 10.244430
Training AE 2/2
epoch 1/1. Took 1.5827 seconds. Mini-batch mean squared error on training set is 3.4285; Full-batch train err = 1.692853
epoch 1/1. Took 3.4583 seconds. Mini-batch mean squared error on training set is 0.16406; Full-batch train err = 0.102337
function sae = saesetup(size)
    % one autoencoder per consecutive pair of layer sizes: [size(u-1) size(u) size(u-1)]
    for u = 2 : numel(size)
        sae.ae{u-1} = nnsetup([size(u-1) size(u) size(u-1)]);
    end
end
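saesetup turns a list of layer sizes into one shallow autoencoder per consecutive pair, each mapping size(u-1) -> size(u) -> size(u-1). A Python sketch of the architectures it produces, returning plain size lists instead of real networks (`sae_architectures` is a hypothetical name):

```python
from typing import List

def sae_architectures(sizes: List[int]) -> List[List[int]]:
    """One [n_in, n_hidden, n_in] autoencoder per consecutive pair of sizes, like saesetup."""
    return [[sizes[u - 1], sizes[u], sizes[u - 1]] for u in range(1, len(sizes))]

print(sae_architectures([784, 100]))       # [[784, 100, 784]]
print(sae_architectures([784, 100, 100]))  # [[784, 100, 784], [100, 100, 100]]
```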
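function sae = saetrain(sae, x, opts)
    for i = 1 : numel(sae.ae)  % in ex1, sae.ae is a 1x1 cell array
        disp(['Training AE ' num2str(i) '/' num2str(numel(sae.ae))]);
        sae.ae{i} = nntrain(sae.ae{i}, x, x, opts);  % train AE i to reconstruct its own input
        t = nnff(sae.ae{i}, x, x);                   % forward pass to get the hidden activations
        x = t.a{2};                                  % hidden layer becomes the input of the next AE
        % remove bias term
        x = x(:,2:end);
    end
end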
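saetrain is greedy layer-wise pretraining: autoencoder i is trained to reconstruct its own input, then its hidden activations (minus the bias column that nnff prepends to `a{2}`) become the training data for autoencoder i+1. A Python sketch of that loop, assuming each autoencoder is reduced to its encoder matrix (bias weights in column 0) and that the per-AE training step is supplied by a hypothetical `train_ae` callable:

```python
import math
from typing import Callable, List

Matrix = List[List[float]]

def sigm(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def hidden_with_bias(W: Matrix, x: Matrix) -> Matrix:
    """Sigmoid hidden layer with a leading bias column of 1s (the layout of t.a{2})."""
    out = []
    for row in x:
        xb = [1.0] + row  # input with bias unit prepended
        h = [sigm(sum(w * v for w, v in zip(w_row, xb))) for w_row in W]
        out.append([1.0] + h)  # hidden activations with bias unit prepended
    return out

def sae_train(encoders: List[Matrix], x: Matrix,
              train_ae: Callable[[Matrix, Matrix], Matrix]) -> List[Matrix]:
    """Greedy layer-wise pretraining over a stack of autoencoders."""
    trained = []
    for i, W in enumerate(encoders, start=1):
        print(f"Training AE {i}/{len(encoders)}")
        W = train_ae(W, x)           # hypothetical trainer: fit AE i to map x back to x
        trained.append(W)
        a2 = hidden_with_bias(W, x)  # forward pass through the trained encoder
        x = [row[1:] for row in a2]  # remove bias term; this feeds the next AE
    return trained
```

With a stand-in trainer that returns the weights unchanged, `sae_train([W1, W2], [[1.0, 2.0, 3.0]], lambda W, x: W)` trains a 3 -> 2 encoder and then feeds its 2 hidden activations to a 2 -> 2 encoder.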
function r=visualize(X, mm, s1, s2)
%FROM RBMLIB
%Visualize weights X. If the function is called as a void method,
%it does the plotting. But if the function is assigned to a variable
%outside of this code, the formed image is returned instead.
if ~exist('mm','var')
    mm = [min(X(:)) max(X(:))];
end
if ~exist('s1','var')
    s1 = 0;
end
if ~exist('s2','var')
    s2 = 0;
end

[D,N]= size(X);
s=sqrt(D);
if s==floor(s) || (s1 ~=0 && s2 ~=0)
    if (s1 ==0 || s2 ==0)
        s1 = s; s2 = s;
    end
    %it's a square, so data is probably an image
    num=ceil(sqrt(N));
    a=mm(2)*ones(num*s2+num-1,num*s1+num-1);
    x=0;
    y=0;
    for i=1:N
        im = reshape(X(:,i),s1,s2)';
        a(x*s2+1+x : x*s2+s2+x, y*s1+1+y : y*s1+s1+y)=im;
        x=x+1;
        if(x>=num)
            x=0;
            y=y+1;
        end
    end
    d=true;
else
    %there is not much we can do
    a=X;
end

%return the image, or plot the image
if nargout==1
    r=a;
else
    imagesc(a, [mm(1) mm(2)]);
    axis equal
    colormap gray
end
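visualize tiles the columns of X (in ex1, 100 columns of 784 = 28x28 pixels) into a grid of ceil(sqrt(N)) tiles per side, separated by one-pixel borders filled with mm(2), the maximum value by default. The subtle part is the indexing: MATLAB's reshape is column-major and the result is transposed, so pixel (r, c) of a tile is element r*s + c of the column (0-based), and tiles fill downward before moving right. A Python sketch of the square case only, leaving out custom s1/s2 and the imagesc plotting (`tile_columns` is a hypothetical name):

```python
import math
from typing import List

def tile_columns(X: List[List[float]]) -> List[List[float]]:
    """Tile the columns of X (D x N, D a perfect square) like visualize's square case."""
    D, N = len(X), len(X[0])
    s = math.isqrt(D)
    if s * s != D:
        raise ValueError("D is not a perfect square; visualize would show X unchanged")
    border = max(max(row) for row in X)  # mm(2): separators use the max value
    num = math.ceil(math.sqrt(N))        # tiles per side
    side = num * s + num - 1             # tiles plus 1-pixel separators
    canvas = [[border] * side for _ in range(side)]
    x = y = 0                            # x: tile row, y: tile column
    for i in range(N):
        col = [X[d][i] for d in range(D)]
        for r in range(s):
            for c in range(s):
                # reshape (column-major) followed by transpose => im[r][c] = col[r*s + c]
                canvas[x * (s + 1) + r][y * (s + 1) + c] = col[r * s + c]
        x += 1
        if x >= num:                     # column of tiles is full: move one tile right
            x, y = 0, y + 1
    return canvas

X = [[1, 5], [2, 6], [3, 7], [4, 8]]  # D=4 pixels (2x2 tiles), N=2 filters
for row in tile_columns(X):
    print(row)  # second tile lands below the first, at rows 3-4
```

For ex1 (D = 784, N = 100) this gives a 10x10 grid on a 289x289 canvas.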