【deep learning学习笔记】注释yusugomori的LR代码 --- 模型测试
来源:互联网 发布:手机淘宝店招尺寸2017 编辑:程序博客网 时间:2024/05/21 22:32
测试部分代码:
void test_lr() { srand(0); double learning_rate = 0.1; double n_epochs = 500; int train_N = 6; int test_N = 2; int n_in = 6; int n_out = 2; // int **train_X; // int **train_Y; // int **test_X; // double **test_Y; // train_X = new int*[train_N]; // train_Y = new int*[train_N]; // for(i=0; i<train_N; i++){ // train_X[i] = new int[n_in]; // train_Y[i] = new int[n_out]; // }; // test_X = new int*[test_N]; // test_Y = new double*[test_N]; // for(i=0; i<test_N; i++){ // test_X[i] = new int[n_in]; // test_Y[i] = new double[n_out]; // } // training data int train_X[6][6] = { {1, 1, 1, 0, 0, 0}, {1, 0, 1, 0, 0, 0}, {1, 1, 1, 0, 0, 0}, {0, 0, 1, 1, 1, 0}, {0, 0, 1, 1, 0, 0}, {0, 0, 1, 1, 1, 0} }; int train_Y[6][2] = { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1} }; // construct LogisticRegression LogisticRegression classifier(train_N, n_in, n_out); // i wonder that we should set the N value to 1 as training online //LogisticRegression classifier(1, n_in, n_out); // train online for(int epoch=0; epoch<n_epochs; epoch++) { for(int i=0; i<train_N; i++) { classifier.train(train_X[i], train_Y[i], learning_rate); } // learning_rate *= 0.95; } // test data int test_X[2][6] = { {1, 0, 1, 0, 0, 0}, {0, 0, 1, 1, 1, 0} }; double test_Y[2][2]; // test for(int i=0; i<test_N; i++) { classifier.predict(test_X[i], test_Y[i]); for(int j=0; j<n_out; j++) { cout << test_Y[i][j] << " "; } cout << endl; }}int main() { test_lr(); getchar(); return 0;}测试数据实际上是在训练集合中的,分别是第二个和第四个训练数据,也就是说,这是“封闭测试”。测试结果如下所示:
不过总感觉这个调用
“
LogisticRegression classifier(train_N, n_in, n_out);”
不对。在线训练,是单个样本为单位的训练,train_N的值应该设置为1。将这一句改成
“
LogisticRegression classifier(1, n_in, n_out);”
运行结果如下:
与上面的结果差别不大。恐怕要到实际应用中检验了。
- 【deep learning学习笔记】注释yusugomori的LR代码 --- 模型测试
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- cpp文件 -- 模型测试
- 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.cpp -- 模型测试
- 【deep learning学习笔记】注释yusugomori的SDA代码 -- Sda.cpp -- 模型测试
- 【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.h
- 【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp
- 【deep learning学习笔记】注释yusugomori的SDA代码 -- 准备工作
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- cpp文件 -- 模型训练
- 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.cpp --模型准备
- 【deep learning学习笔记】注释yusugomori的SDA代码 -- Sda.cpp -- 模型准备
- 【deep learning学习笔记】注释yusugomori的SDA代码 -- Sda.cpp -- 模型训练与预测
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- 头文件
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- cpp文件 -- 准备工作
- 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.h
- 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.cpp -- 训练
- 【deep learning学习笔记】注释yusugomori的RBM代码 --- 头文件
- 【deep learning学习笔记】注释yusugomori的SDA代码 -- Sda.h
- deep learning 学习笔记
- PostgreSQL的备份与还原
- Linux_problem 1: ntfs-3g
- 数据库模式
- IE8下,js调用另一页面js方法,不能alert问题
- java基础——异常处理
- 【deep learning学习笔记】注释yusugomori的LR代码 --- 模型测试
- 修改和删除概要文件
- jQueryMobile学习笔记二
- C# 中的委托和事件
- 关于json
- 二叉树的遍历
- UVA:10071 - Back to High School Physics
- arcgis for android 离线地图实现
- 线程类加载器