【deep learning学习笔记】注释yusugomori的RBM代码 --- cpp文件 -- 模型测试

来源:互联网 发布:中国淘宝村高峰论坛 编辑:程序博客网 时间:2024/06/14 10:39

产生数据,调用上文的函数,训练RBM模型,并re-construct测试数据,用来验证训练的RBM模型。

void test_rbm() {  srand(0);  double learning_rate = 0.1;  int training_epochs = 1000;  int k = 1;    int train_N = 6;  int test_N = 2;  int n_visible = 6;  int n_hidden = 3;  // training data  int train_X[6][6] = {    {1, 1, 1, 0, 0, 0},    {1, 0, 1, 0, 0, 0},    {1, 1, 1, 0, 0, 0},    {0, 0, 1, 1, 1, 0},    {0, 0, 1, 0, 1, 0},    {0, 0, 1, 1, 1, 0}  };  // construct RBM  RBM rbm(train_N, n_visible, n_hidden, NULL, NULL, NULL);  // train  for(int epoch=0; epoch<training_epochs; epoch++)   {// iterator all the training samples, and apply the CD-k algoirthm    for(int i=0; i<train_N; i++) {      rbm.contrastive_divergence(train_X[i], learning_rate, k);    }  }  // test data  int test_X[2][6] = {    {1, 1, 0, 0, 0, 0},    {0, 0, 0, 1, 1, 0}  };  double reconstructed_X[2][6];  // test  for(int i=0; i<test_N; i++)   {    rbm.reconstruct(test_X[i], reconstructed_X[i]);    for(int j=0; j<n_visible; j++) {      printf("%.5f ", reconstructed_X[i][j]);    }    cout << endl;  }}int main() {  test_rbm();  return 0;}

运行结果如图: