Denosing Autoencoder训练过程代码详解

来源:互联网 发布:mac好用的office软件 编辑:程序博客网 时间:2024/06/07 02:27

普通deep autoencoder训练过程

本文主要参考Deeplearn toolbox中代码

matlabDeep Learning toolbox,见:https://github.com/rasmusbergpalm/DeepLearnToolbox

一:加载数据 

二:pre_training阶段

2.1初始化DAE网络框架

sae = saesetup([784 100 100]);%建立一个3层网络

在函数saesetup函数内部,循环调用nnsetup函数,此处2次调用nnsetup函数

最终saesetup([784100 100]);建立一个2个autoencoder网络

分别是:784 ——100——784;100 ——100——100

2.1.1   nnsetup函数说明;

nn.size   =architecture;   nn.n= numel(nn.size)=3;

初始化激活函数类型,学习率,动量项,稀疏项等参数。

初始化网络结构:

for i = 2 : nn.n=3   %第一次循环调用:输入nnsetup([784  100 784])

经过2次循环,初始化一个784 ——100——784 的autoencoder网络。

2.2初始化SAE网络的训练参数

初始化激活函数类型(此处默认为tanh_opt函数),学习率,噪声比例等参数。

2.3开始训练SAE网络

sae = saetrain(sae, train_x,opts);%输入网络结构,样本数据,批处理的个数

saetrain内部调用nntrain函数来训练sae的每个autoencoder子网络.

sae.ae{i} = nntrain(sae.ae{i},x, x, opts);%循环两次

2.3.1 nntrain函数说明

输入参数:nntrain(nn, train_x, train_y, opts,val_x, val_y)

由于是autoencoder网络,此处的train_y是输入数据train_x

整个数据循环训练numepochs次,每次将整个数据分成numbatches个组,进行minibatch训练。

每次minibatch训练过程:

从整个数据中,“随机提取”minibatch个batch_x和batch_y数据;

调用nnff函数前馈计算 损失函数nn.L,和误差向量 nn.e

调用nnbp函数,反向计算误差dw{i}

调用nnapplygrads函数,来更新权值

三:fine_tuning 阶段

3.1初始化前馈网络

nn = nnsetup([784 100 100 10]);

把pre_training阶段训练的权值矩阵赋值给前馈网络。

nn.W{1} = sae.ae{1}.W{1};

nn.W{2} = sae.ae{2}.W{1};

 

3.2标签数据来fine_tuning整个网络

nn = nntrain(nn, train_x,train_y, opts);

 

四:测试数据

[er, bad] = nntest(nn, test_x,test_y);

 4.1 预测样本分类

labels = nnpredict(nn, x);

nnpredict函数首先实现一个网络的前馈计算,计算样本属于每个分类的“概率”

nn = nnff(nn, x,zeros(size(x,1), nn.size(end)));

根据最大化原则,确定样本的分类,并提取类别标签

[~, i] = max(nn.a{end},[],2);%max(a,[],2)提取矩阵a每行的最大值

labels = i;

4.2计算错误率

[~, expected] = max(y,[],2);

bad = find(labels ~= expected);

er = numel(bad) / size(x, 1);

 

 

Denosing Autoencoder训练过程

Denoising autoencode主要通过在nntrain函数中对输入数据加入噪声,其他训练部分相同

一:Pre_training阶段

数据加噪:

sae.ae{1}.inputZeroMaskedFraction   = 0.5;

sae.ae{2}.inputZeroMaskedFraction   = 0.5;

由于是使用saetrain函数来分别训练每个autoencoder网络,所以这里对每个网络的输入数据都加入50%的随机噪声。

加噪代码:

if(nn.inputZeroMaskedFraction ~= 0)

batch_x =batch_x.*(rand(size(batch_x))>nn.inputZeroMaskedFraction);

end

加入50%随机噪声后,原始数据


二:fine_tuning阶段

由于此处直接使用的是nntrain,微调整个网络;所以只在在调用nntrain函数时,把原始输入数据train_x加入噪声

0 0
原创粉丝点击