DeepLearning(深度学习)原理与实现(二)

来源:互联网 发布:淘宝网的古悦堂怎么样 编辑:程序博客网 时间:2024/05/28 05:17

下面贴出RBM C++版本的代码,一些大牛写的,结合上篇博文来加深大家对RBM理论的理解。。。


RBM类定义声明:

[cpp] view plaincopyprint?
  1. class RBM {  
  2.   
  3. public:  
  4.   int N;  
  5.   int n_visible;  
  6.   int n_hidden;  
  7.   double **W;  
  8.   double *hbias;  
  9.   double *vbias;  
  10.   RBM(intintintdouble**, double*, double*);  
  11.   ~RBM();  
  12.   void contrastive_divergence(int*, doubleint);  
  13.   void sample_h_given_v(int*, double*, int*);  
  14.   void sample_v_given_h(int*, double*, int*);  
  15.   double propup(int*, double*, double);  
  16.   double propdown(int*, intdouble);  
  17.   void gibbs_hvh(int*, double*, int*, double*, int*);  
  18.   void reconstruct(int*, double*);  
  19. };  

从上面声明中可以很直观的看出和上篇文章公式符号正好完美对应。下面是代码实现部分:
[cpp] view plaincopyprint?
  1.   
[cpp] view plaincopyprint?
  1. #include <iostream>   
  2. #include <math.h>   
  3. #include "RBM.h"   
  4. using namespace std;  
  5.   
  6. double uniform(double min, double max) {  
  7.   return rand() / (RAND_MAX + 1.0) * (max - min) + min;  
  8. }  
  9.   
  10. int binomial(int n, double p) {  
  11.   if(p < 0 || p > 1) return 0;  
  12.     
  13.   int c = 0;  
  14.   double r;  
  15.     
  16.   for(int i=0; i<n; i++) {  
  17.     r = rand() / (RAND_MAX + 1.0);  
  18.     if (r < p) c++;  
  19.   }  
  20.   
  21.   return c;  
  22. }  
  23.   
  24. double sigmoid(double x) {  
  25.   return 1.0 / (1.0 + exp(-x));  
  26. }  
  27.   
  28.   
  29. RBM::RBM(int size, int n_v, int n_h, double **w, double *hb, double *vb) {  
  30.   N = size;  
  31.   n_visible = n_v;  
  32.   n_hidden = n_h;  
  33.   
  34.   if(w == NULL) {  
  35.     W = new double*[n_hidden];  
  36.     for(int i=0; i<n_hidden; i++) W[i] = new double[n_visible];  
  37.     double a = 1.0 / n_visible;  
  38.   
  39.     for(int i=0; i<n_hidden; i++) {  
  40.       for(int j=0; j<n_visible; j++) {  
  41.         W[i][j] = uniform(-a, a);  
  42.       }  
  43.     }  
  44.   } else {  
  45.     W = w;  
  46.   }  
  47.   
  48.   if(hb == NULL) {  
  49.     hbias = new double[n_hidden];  
  50.     for(int i=0; i<n_hidden; i++) hbias[i] = 0;  
  51.   } else {  
  52.     hbias = hb;  
  53.   }  
  54.   
  55.   if(vb == NULL) {  
  56.     vbias = new double[n_visible];  
  57.     for(int i=0; i<n_visible; i++) vbias[i] = 0;  
  58.   } else {  
  59.     vbias = vb;  
  60.   }  
  61. }  
  62.   
  63. RBM::~RBM() {  
  64.   for(int i=0; i<n_hidden; i++) delete[] W[i];  
  65.   delete[] W;  
  66.   delete[] hbias;  
  67.   delete[] vbias;  
  68. }  
  69.   
  70.   
  71. void RBM::contrastive_divergence(int *input, double lr, int k) {  
  72.   double *ph_mean = new double[n_hidden];  
  73.   int *ph_sample = new int[n_hidden];  
  74.   double *nv_means = new double[n_visible];  
  75.   int *nv_samples = new int[n_visible];  
  76.   double *nh_means = new double[n_hidden];  
  77.   int *nh_samples = new int[n_hidden];  
  78.   
  79.   /* CD-k */  
  80.   sample_h_given_v(input, ph_mean, ph_sample);  
  81.   
  82.   for(int step=0; step<k; step++) {  
  83.     if(step == 0) {  
  84.       gibbs_hvh(ph_sample, nv_means, nv_samples, nh_means, nh_samples);  
  85.     } else {  
  86.       gibbs_hvh(nh_samples, nv_means, nv_samples, nh_means, nh_samples);  
  87.     }  
  88.   }  
  89.   
  90.   for(int i=0; i<n_hidden; i++) {  
  91.     for(int j=0; j<n_visible; j++) {  
  92.       W[i][j] += lr * (ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]) / N;  
  93.     }  
  94.     hbias[i] += lr * (ph_sample[i] - nh_means[i]) / N;  
  95.   }  
  96.   
  97.   for(int i=0; i<n_visible; i++) {  
  98.     vbias[i] += lr * (input[i] - nv_samples[i]) / N;  
  99.   }  
  100.   
  101.   delete[] ph_mean;  
  102.   delete[] ph_sample;  
  103.   delete[] nv_means;  
  104.   delete[] nv_samples;  
  105.   delete[] nh_means;  
  106.   delete[] nh_samples;  
  107. }  
  108.   
  109. void RBM::sample_h_given_v(int *v0_sample, double *mean, int *sample) {  
  110.   for(int i=0; i<n_hidden; i++) {  
  111.     mean[i] = propup(v0_sample, W[i], hbias[i]);  
  112.     sample[i] = binomial(1, mean[i]);  
  113.   }  
  114. }  
  115.   
  116. void RBM::sample_v_given_h(int *h0_sample, double *mean, int *sample) {  
  117.   for(int i=0; i<n_visible; i++) {  
  118.     mean[i] = propdown(h0_sample, i, vbias[i]);  
  119.     sample[i] = binomial(1, mean[i]);  
  120.   }  
  121. }  
  122.   
  123. double RBM::propup(int *v, double *w, double b) {  
  124.   double pre_sigmoid_activation = 0.0;  
  125.   for(int j=0; j<n_visible; j++) {  
  126.     pre_sigmoid_activation += w[j] * v[j];  
  127.   }  
  128.   pre_sigmoid_activation += b;  
  129.   return sigmoid(pre_sigmoid_activation);  
  130. }  
  131.   
  132. double RBM::propdown(int *h, int i, double b) {  
  133.   double pre_sigmoid_activation = 0.0;  
  134.   for(int j=0; j<n_hidden; j++) {  
  135.     pre_sigmoid_activation += W[j][i] * h[j];  
  136.   }  
  137.   pre_sigmoid_activation += b;  
  138.   return sigmoid(pre_sigmoid_activation);  
  139. }  
  140.   
  141. void RBM::gibbs_hvh(int *h0_sample, double *nv_means, int *nv_samples, \  
  142.                     double *nh_means, int *nh_samples) {  
  143.   sample_v_given_h(h0_sample, nv_means, nv_samples);  
  144.   sample_h_given_v(nv_samples, nh_means, nh_samples);  
  145. }  
  146.   
  147. void RBM::reconstruct(int *v, double *reconstructed_v) {  
  148.   double *h = new double[n_hidden];  
  149.   double pre_sigmoid_activation;  
  150.   
  151.   for(int i=0; i<n_hidden; i++) {  
  152.     h[i] = propup(v, W[i], hbias[i]);  
  153.   }  
  154.   
  155.   for(int i=0; i<n_visible; i++) {  
  156.     pre_sigmoid_activation = 0.0;  
  157.     for(int j=0; j<n_hidden; j++) {  
  158.       pre_sigmoid_activation += W[j][i] * h[j];  
  159.     }  
  160.     pre_sigmoid_activation += vbias[i];  
  161.   
  162.     reconstructed_v[i] = sigmoid(pre_sigmoid_activation);  
  163.   }  
  164.   
  165.   delete[] h;  
  166. }  
  167.   
  168.   
  169. void test_rbm() {  
  170.   srand(0);  
  171.   
  172.   double learning_rate = 0.1;  
  173.   int training_epochs = 1000;  
  174.   int k = 1;  
  175.     
  176.   int train_N = 6;  
  177.   int test_N = 2;  
  178.   int n_visible = 6;  
  179.   int n_hidden = 3;  
  180.   
  181.   // training data   
  182.   int train_X[6][6] = {  
  183.     {1, 1, 1, 0, 0, 0},  
  184.     {1, 0, 1, 0, 0, 0},  
  185.     {1, 1, 1, 0, 0, 0},  
  186.     {0, 0, 1, 1, 1, 0},  
  187.     {0, 0, 1, 0, 1, 0},  
  188.     {0, 0, 1, 1, 1, 0}  
  189.   };  
  190.   
  191.   
  192.   // construct RBM   
  193.   RBM rbm(train_N, n_visible, n_hidden, NULL, NULL, NULL);  
  194.   
  195.   // train   
  196.   for(int epoch=0; epoch<training_epochs; epoch++) {  
  197.     for(int i=0; i<train_N; i++) {  
  198.       rbm.contrastive_divergence(train_X[i], learning_rate, k);  
  199.     }  
  200.   }  
  201.   
  202.   // test data   
  203.   int test_X[2][6] = {  
  204.     {1, 1, 0, 0, 0, 0},  
  205.     {0, 0, 0, 1, 1, 0}  
  206.   };  
  207.   double reconstructed_X[2][6];  
  208.   
  209.   
  210.   // test   
  211.   for(int i=0; i<test_N; i++) {  
  212.     rbm.reconstruct(test_X[i], reconstructed_X[i]);  
  213.     for(int j=0; j<n_visible; j++) {  
  214.       printf("%.5f ", reconstructed_X[i][j]);  
  215.     }  
  216.     cout << endl;  
  217.   }  
  218.   
  219. }  
  220.   
  221.   
  222.   
  223. int main() {  
  224.   test_rbm();  
  225.   return 0;  
  226. }  


干脆把运行结果也贴出来,给那些终极极品思考者提供一些方便偷笑

0.98472  0.67248  0.99120  0.01000  0.01311  0.01020
0.01021  0.00720  0.99525  0.65553  0.98403  0.00497


转载请注明出处:http://blog.csdn.net/cuoqu/article/details/8887882


原创粉丝点击