OpenCV的machine learning模块使用

来源:互联网 发布:零点微网络是什么 编辑:程序博客网 时间:2024/06/08 08:02

opencv中提供的了较为完善的machine learning 模块,包含多种ml算法,极大了简化了实验过程。然而目前网上大部分的资料(包括官方文档)中关于ml模块的使用均是针对1.0风格的旧代码的,这对我们的学习造成了极大的困扰。本文将简单介绍一下如何使用opencv的ml模块进行实验。 
首先,准备实验数据,我这里使用的是《模式分类》一书中第二章上机习题的部分数据,旨在进行一个简单的调用过程进行实验。实验数据如下表所示,在实际实验过程中,使用txt文档保存数据,并且没有文件头信息(实际上opencv提供了从csv文档读取数据的功能,这里简化实验没有使用该函数)。

类别特征1特征2特征31-5.01-8.12-3.681-5.43-3.48-3.5411.08-5.521.6610.86-3.78-4.111-2.670.637.3914.943.292.081-2.512.09-2.591-2.25-2.13-6.9415.562.86-2.2611.03-3.334.33-1-0.91-0.18-0.05-11.3-2.06-3.53-1-7.75-4.54-0.95-1-5.470.53.92-16.145.72-4.85-13.61.264.63-15.37-4.63-3.65-17.181.46-6.66-1-7.391.176.3-1-7.5-6.32-0.31

实验思路如下:

  • 读取数据,并构造训练样本的特征矩阵,标记矩阵(这里使用1和-1进行标记);
  • 使用合适的分类器进行训练;
  • 使用训练好的分类器进行分类(这里直接使用训练样本进行测试,便于直观看出测试结果,同时简化实验)

实验代码如下所示:

  1. #include <opencv2/core.hpp>
  2. #include <opencv2/core/utility.hpp>
  3. #include <opencv2/highgui.hpp>
  4. #include <opencv2/ml.hpp>
  5. #include <iostream>
  6. #include <fstream>
  7. #include <vector>
  8. using namespace std;
  9. using namespace cv;
  10. using namespace cv::ml;
  11. void svm_classifier(Mat &training_data_mat, Mat &label_mat);
  12. void bayes_classifier(Mat &training_data_mat, Mat &label_mat);
  13. int main()
  14. {
  15. const int class_num = 2;
  16. const int feature_num = 3;
  17. ifstream file("E:\\programs\\Dec\\bayes_opencv\\data.txt");
  18. float value;
  19. vector<float> data_vec;
  20. while (!file.eof())
  21. {
  22. file >> value;
  23. data_vec.push_back(value);
  24. }
  25. Mat data(data_vec);
  26. data = data.reshape(0, 20);
  27. Mat training_data_mat = data.colRange(1, data.cols);
  28. Mat lable_mat(data.col(0));
  29. lable_mat.convertTo(lable_mat, CV_32SC1);
  30. cout << "bayes classifier" << endl;
  31. bayes_classifier(training_data_mat, lable_mat);
  32. cout << "svm classifier" << endl;
  33. svm_classifier(training_data_mat, lable_mat);
  34. return 0;
  35. }
  36. void svm_classifier(Mat &training_data_mat, Mat &lable_mat)
  37. {
  38. SVM::Params params;
  39. params.svmType = SVM::C_SVC;
  40. params.kernelType = SVM::LINEAR;
  41. params.termCrit = TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6);
  42. Ptr<SVM> svm = StatModel::train<SVM>(training_data_mat, ROW_SAMPLE, lable_mat, params);
  43. for (size_t i = 0; i != training_data_mat.rows; ++i)
  44. {
  45. Mat test_mat = training_data_mat.row(i);
  46. float response = svm->predict(test_mat);
  47. cout << i + 1 << ":\t" << response << flush;
  48. MatIterator_<float> it, end;
  49. for (it = test_mat.begin<float>(), end = test_mat.end<float>(); it != end; ++it)
  50. {
  51. cout << '\t' << *it << flush;
  52. }
  53. cout << endl;
  54. }
  55. }
  56. void bayes_classifier(Mat &training_data_mat, Mat &lable_mat)
  57. {
  58. NormalBayesClassifier::Params params;
  59. Ptr<TrainData> train_data = TrainData::create(training_data_mat, ROW_SAMPLE, lable_mat);
  60. Ptr<NormalBayesClassifier> bayes = StatModel::train<NormalBayesClassifier>(train_data, params, 0);
  61. for (size_t i = 0; i != training_data_mat.rows; ++i)
  62. {
  63. Mat test_mat = training_data_mat.row(i);
  64. float response = bayes->predict(test_mat);
  65. cout << i + 1 << ":\t" << response << flush;
  66. MatIterator_<float> it, end;
  67. for (it = test_mat.begin<float>(), end = test_mat.end<float>(); it != end; ++it)
  68. {
  69. cout << '\t' << *it << flush;
  70. }
  71. cout << endl;
  72. }
  73. }

代码运行结果如图所示: 
opencv贝叶斯、svm分类结果 
从上述代码可以看出,实验使用了SVM和normal baysian classifier两个分类器,两分类器具有相同结构的使用方式,不同的是SVM使用前需要设置合适的参数,而贝叶斯分类器不需要。实验主要使用了两个函数:

  1. //训练
  2. Ptr<SVM> svm = StatModel::train<SVM>(training_data_mat, ROW_SAMPLE, lable_mat, params);
  3. Ptr<NormalBayesClassifier> bayes = StatModel::train<NormalBayesClassifier>(train_data, params, 0);
  4. //测试
  5. float response = svm->predict(test_mat);
  6. float response = bayes->predict(test_mat);

这里训练svm和bayes时使用了两个不同的函数,这是因为分类器类SVM和NormalBayesClassifier均继承自StatModel这个类别,它提供了两个重载的静态函数进行训练。从源代码可以看出:

  1. template<typename _Tp> static Ptr<_Tp> train(const Ptr<TrainData>& data, const typename _Tp::Params& p, int flags=0)
  2. {
  3. Ptr<_Tp> model = _Tp::create(p);
  4. return !model.empty() && model->train(data, flags) ? model : Ptr<_Tp>();
  5. }
  6. template<typename _Tp> static Ptr<_Tp> train(InputArray samples, int layout, InputArray responses,
  7. const typename _Tp::Params& p, int flags=0)
  8. {
  9. Ptr<_Tp> model = _Tp::create(p);
  10. return !model.empty() && model->train(TrainData::create(samples, layout, responses), flags) ? model : Ptr<_Tp>();
  11. }

所以,实际上第二个函数使用的时候根据输入的参数samples, layout和response构造了TrainData类对象,并调用了第一个函数。这里的TrainData类即保存ml算法使用数据的类,这里不做详细分析(后期会写相关文章,分析其源代码)。samples是训练样本特征的矩阵,layout参数有ROW_SAMPLE和COL_SMAPLE两个选择,说明了样本矩阵中一行还是一列代表一个样本,response矩阵和samples矩阵相对应,说明了样本的标记,本例中为1和-1. 
从上面的代码中可以看出ml算法的使用方法,实际上opencv的ml模块提供的所有分类器均继承自StatModel这个抽象类,他们的使用方法均和SVM和NormalBayesClassifier类似。其包含的所有ml算法如下:

  • NormalBayesClassifier (贝叶斯分类器~~~符合正态分布的)
  • KNearest (KNN算法)
  • SVM
  • EM
  • DTrees (决策树)
  • RTrees (随机森林)
  • Boost (boosted tree classifer)
  • ANN_MLP (人工神经网络)

后续内容:本文是使用opencv学习ml算法的初次尝试,后面还会介绍更多的相关内容,包括opencv的源码学习,更多的machine learning算法介绍,如何对图像进行分类等,敬请期待。


3 0
原创粉丝点击