程序片段----opencv cv::ml::KNearest knn 20170904

来源:互联网 发布:淘宝实木花架中式图片 编辑:程序博客网 时间:2024/06/06 14:14
// opencv3 knn 的实例// 样本是随机数生成的,不需要额外数据集。// knn  : k 是要设定的参数,意义是:将待测样本X最近的k个点进行比较,A类型的点最多,那么认为待测样本X是A类型。// 环境 : opencv3.0.0 \ vs2012 32 bits \ win7// 环境搭建://# 1. 新建工程,opencv基本配置//  # 2. 将 opencv\source\module\ml\src 中的 knearest.cpp 复制到新建的工程目录下,文件名改为 knearest.hpp //  # 3. 包含该头文件//  # 4. mian函数修改如下//  # 5. 编译运行。// #include <opencv2/opencv.hpp>#include <opencv2/core/core.hpp>#include <opencv2/highgui/highgui.hpp>#include <opencv2/ml/ml.hpp>#include <knearest.hpp>using namespace cv::ml;int main( ){    const int K = 10;    int i, j, k, accuracy;    float response;    int train_sample_count = 100;    cv::RNG rng_state(-1);    cv::Mat trainData(train_sample_count,2,CV_32FC1);    cv::Mat trainClasses(train_sample_count,1,CV_32FC1); /// labels    cv::Mat img(cv::Size(500,500),CV_8UC3,cv::Scalar::all (0));    float _sample[2];    cv::Mat sample(1,2,CV_32FC1,_sample); /// just 1 sample    cv::Mat trainData1, trainData2, trainClasses1, trainClasses2;    // form the training samples    trainData1 = trainData.rowRange (0,train_sample_count/2);    rng_state.fill (trainData1,CV_RAND_NORMAL,cv::Scalar(200,200),cv::Scalar(50,50));    trainData2 = trainData.rowRange (train_sample_count/2,train_sample_count);    rng_state.fill (trainData2,CV_RAND_NORMAL,cv::Scalar(300,300),cv::Scalar(50,50));    trainClasses1 = trainClasses.rowRange (0,train_sample_count/2);    trainClasses1.setTo (1);    trainClasses2 = trainClasses.rowRange (train_sample_count/2,train_sample_count);    trainClasses2.setTo (2);    // learn classifier//// cv::ml::KNearest knn( trainData, trainClasses, cv::Mat(), false, K );cv::Ptr<cv::ml::KNearest> knn = KNearest::create();knn->setDefaultK(5);knn->setIsClassifier(true);cv::Ptr<cv::ml::TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainClasses);knn->train(tData);    cv::Mat nearests( 1, K, CV_32FC1); //// closet k points    for( i = 0; i < img.rows; i++ ) //// 将图中各点作为样本。红是1类型置信值高的,绿是2类型置信值高的,棕色是不确定    {        for( j = 0; j < img.cols; j++ )        {            sample.at<float>(0,0) = (float)j;            sample.at<float>(0,1) = (float)i;cv::Mat result(sample.size(), CV_32FC1);            // estimate the response and get the neighbors' labels/// response = knn->findNearest(sample,K,0,0,&nearests,0);response = knn->findNearest(sample, K, result, nearests);            // compute the number of neighbors representing the majority            for( k = 0, accuracy = 0; k < K; k++ )            {                if( nearests.at<float>(0,k) == response)                    accuracy++; /// 最近邻的k个中,该类型的占比            }            // highlight the pixel depending on the accuracy (or confidence)            img.at<cv::Vec3b>(i,j) = response == 1 ?                        (accuracy > 5 ? cv::Vec3b(0,0,180) : cv::Vec3b(0,120,180)) :                        (accuracy > 5 ? cv::Vec3b(0,180,0) : cv::Vec3b(0,120,120));        }    }    // display the original training samples    for( i = 0; i < train_sample_count/2; i++ ) ///// 圈 是样本    {        cv::Point pt;        pt.x = cvRound(trainData1.at<float>(i,0));        pt.y = cvRound(trainData1.at<float>(i,1));        cv::circle (img,pt,2,cv::Scalar(0,0,255),1,CV_FILLED);        pt.x = cvRound(trainData2.at<float>(i,0));        pt.y = cvRound(trainData2.at<float>(i,1));        cv::circle (img,pt,2,cv::Scalar(0,255,0),1,CV_FILLED);    }    cv::namedWindow( "classifier result", 1 );    cv::imshow( "classifier result", img );    cv::waitKey(0);    return 0;}

0 0
原创粉丝点击