DeepLearning(深度学习)原理与实现(四)
来源:互联网 发布:windows丢失无法启动 编辑:程序博客网 时间:2024/05/14 12:10
今天就来讨论下deep learning怎么来处理real valued data。对于图像来说,二值图像毕竟是少数,更多的还是实值图像。对于这样的情况,RBM已经无法很好的处理它们,因此需要改进它,对于了解计算机视觉的人而言,想必高斯混合背景模型大家已不陌生,高斯混合模型可以很好的对实值图像建模,OpenCV中早就用高斯混合背景模型来分割物体。接下来要引出的高斯有限制玻尔兹曼机(Gaussian Restricted Boltzmann machine-GRBM)和高斯混合模型有一定的等价性。
在上一节中,作者提到过对于RBM这个无向图模型而言,无论是给定隐藏节点还是可视节点,剩下的那层内节点之间互相独立。那么对于RBM这个图模型描述的联合分布可以分解成(公式一)的形式:
(公式一)
其中P(h)可看成高斯混合系数,如果P(v|h)是高斯分布,那么这个条件下的RBM就和高斯混合分布具有等价性。因此GRBM合理登场 ,下面开始进入GRBM正题,首先当然要假定P(v|h)服从高斯分布,不然没法进展下去咯,先来看下GRBM的能量公式定义(公式二):
从能量公式中可以很明显的看出有高斯分布的影子,当然还有其他类型的能量公式,关键看你做什么假设,假设不一样,能量公式也不一样,假设的好就会产生一篇好文章哦。其中参数要求的为:。求取算法和此系列文章一方法类似,也是用CD算法,但是对于GRBM的CD算法,需要多啰嗦几句,首先我们把博文一中的求取数据期望项和模型期望项的过程改成官方称呼:分别对应为正阶段(positive phase)和负阶段(negative phase)。
正阶段的目的就是找到给定数据的情况下,隐藏节点的配置(暂称它为"好配置"),并且使得能量降低。负阶段的目的当然就是抬升"好配置"周围的配置的能量,可能有点拗口,推荐看下博文一中的参考文献:Energy based model tutorial,可以把它总结成一句话:把和实际样本有关的配置能量降低,把fancy的样本的能量抬高,这样得到的权重是对样本的最佳描述,当然样本越多越好,所以大数据这么火。有些人会有疑问,样本多会不会导致过拟合?其实生成模型的魅力之一就在这,样本越多越好。正阶段中隐藏节点h的状态求取仍然用sigmoid函数得出概率,然后采样;但是负阶段中可视节点V的状态的概率就不是用sigmoid函数了,而是用高斯函数,二者的公式分别如(公式三)和(公式四)表示:
(公式三,f为sigmoid函数)
(公式四,高斯分布函数)
对于参数的求取,过程只是多了sigma而已,如果你熟悉高斯混合模型,自然对sigma的实际作用也有些直观认识,庆幸的是sigma仍然也可以在CD算法中与其他参数一样同时求出来。下面给出GRBM中CD算法的部分代码(matlab),其中invfstdInc为sigma,不要用用这段代码,而是用它解释自己对GRBM实现细节的疑问,这段代码中,CD算法引入了momentum(冲量)技巧,可以让算法收敛快些:
%%%%%%%%% START POSITIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% data = batchdata(:,:,batch); %nxd pos_hidprobs = 1./(1 + exp(-(data./Fstd)*vhW - Ones*hb)); %p(h_j =1|data) pos_hidstates = pos_hidprobs > rand( size(pos_hidprobs) ); pos_prods = (data./Fstd)'* pos_hidprobs; pos_hid_act = sum(pos_hidprobs); pos_vis_act = sum(data)./(fstd.^2); %see notes on this derivation %%%%%%%%% END OF POSITIVE PHASE %%%%%%%%% for iterCD = 1:params.nCD %%%%%%%%% START NEGATIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% negdataprobs = pos_hidstates*vhW'.*Fstd+Ones*vb; negdata = negdataprobs + randn(n, d).*Fstd; neg_hidprobs = 1./(1 + exp(-(negdata./Fstd)*vhW - Ones*hb )); %updating hidden nodes again pos_hidstates = neg_hidprobs > rand( size(neg_hidprobs) ); end %end CD iterations neg_prods = (negdata./Fstd)'*neg_hidprobs; neg_hid_act = sum(neg_hidprobs); neg_vis_act = sum(negdata)./(fstd.^2); %see notes for details %%%%%%%%% END OF NEGATIVE PHASE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% errsum = errsum + sum(sum( (data-negdata).^2 )); if epoch > params.init_final_momen_iter, momentum=params.final_momen; else momentum=params.init_momen; end %%%%%%%%% UPDATE WEIGHTS AND BIASES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% vhWInc = momentum*vhWInc + r/n*(pos_prods-neg_prods) - r*params.wtcost*vhW; vbInc = momentum*vbInc + (r/n)*(pos_vis_act-neg_vis_act); hbInc = momentum*hbInc + (r/n)*(pos_hid_act-neg_hid_act); invfstd_grad = sum(2*data.*(Ones*vb-data/2)./Fstd,1) + sum(data' .* (vhW*pos_hidprobs') ,2)'; invfstd_grad = invfstd_grad - ( sum(2*negdata.*(Ones*vb-negdata/2)./Fstd,1) + ... sum( negdata'.*(vhW*neg_hidprobs') ,2 )' ); invfstdInc = momentum*invfstdInc + std_rate(epoch)/n*invfstd_grad; if params.SPARSE == 1 %nair's paper on 3D object recognition %update q if batch==1 && epoch == 1 q = mean(pos_hidprobs); else q_prev = q; q = 0.9*q_prev+0.1*mean(pos_hidprobs); end p = params.sparse_p; grad = 0.1*params.sparse_lambda/n*sum(pos_hidprobs.*(1-pos_hidprobs)).*(p-q)./(q.*(1-q)); gradW =0.1*params.sparse_lambda/n*(data'./Fstd'*(pos_hidprobs.*(1-pos_hidprobs))).*repmat((p-q)./(q.*(1-q)), d,1); hbInc = hbInc + r*grad; vhWInc = vhWInc + r*gradW; end ptot = ptot+mean(pos_hidprobs(:)); vhW = vhW + vhWInc; vb = vb + vbInc; hb = hb + hbInc; invfstd = 1./fstd; invfstd = invfstd + invfstdInc; fstd = 1./invfstd; fstd = max(fstd, 0.005); %have a lower bound!
参考文献:
Guassian-Bernoulli Deep Boltzmann Machine. KyungHyun Cho
- DeepLearning(深度学习)原理与实现(四)
- DeepLearning(深度学习)原理与实现(四)
- DeepLearning(深度学习)原理与实现
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(二)
- DeepLearning(深度学习)原理与实现(三)
- DeepLearning(深度学习)原理与实现(五)
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(二)
- DeepLearning(深度学习)原理与实现(三)
- DeepLearning(深度学习)原理与实现(五)
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(一)
- DeepLearning(深度学习)原理与实现(五)
- DeepLearning(深度学习)原理与实现(一)
- 《深度学习原理与TensorFlow实践》学习笔记(四)
- 单例模式之窗体应用——“唯一”
- MATLAB编译cpp文件
- 留言板完善与问题总结
- [gotoac]强连通+缩点
- suse linux 11 连接 Xmanager
- DeepLearning(深度学习)原理与实现(四)
- Android GridView
- Linux下安装Apache PHP MySql Memcached
- Understanding Caching in Hibernate – Part Two : The Query Cache
- ubuntu apt-get安装和卸载程序
- Asp.Net使用到的控件和技术方法记录!
- JS控制只能往输入框中输入数字
- WPF动态添加XAML
- linux时钟浅析