opencv中使用K-近邻分类算法KNN

来源:互联网 发布:手机写作辅助软件 编辑:程序博客网 时间:2024/06/05 02:57

K-近邻(K-Nearest Neighbors, KNN)是一种很好理解的分类算法,简单说来就是从训练样本中找出K个与其最相近的样本,然后看这K个样本中哪个类别的样本多,则待判定的值(或说抽样)就属于这个类别。

KNN算法的步骤

  • 计算已知类别数据集中每个点与当前点的距离;
  • 选取与当前点距离最小的K个点;
  • 统计前K个点中每个类别的样本出现的频率;
  • 返回前K个点出现频率最高的类别作为当前点的预测分类。

OpenCV中使用CvKNearest

OpenCV中实现CvKNearest类可以实现简单的KNN训练和预测。
[cpp] view plaincopy在CODE上查看代码片派生到我的代码片
  1. int main()  
  2. {  
  3.     float labels[10] = {0,0,0,0,0,1,1,1,1,1};  
  4.     Mat labelsMat(10, 1, CV_32FC1, labels);  
  5.     cout<<labelsMat<<endl;  
  6.     float trainingData[10][2];  
  7.     srand(time(0));   
  8.     for(int i=0;i<5;i++){  
  9.         trainingData[i][0] = rand()%255+1;  
  10.         trainingData[i][1] = rand()%255+1;  
  11.         trainingData[i+5][0] = rand()%255+255;  
  12.         trainingData[i+5][1] = rand()%255+255;  
  13.     }  
  14.     Mat trainingDataMat(10, 2, CV_32FC1, trainingData);  
  15.     cout<<trainingDataMat<<endl;  
  16.     CvKNearest knn;  
  17.     knn.train(trainingDataMat,labelsMat,Mat(), false, 2 );  
  18.     // Data for visual representation  
  19.     int width = 512, height = 512;  
  20.     Mat image = Mat::zeros(height, width, CV_8UC3);  
  21.     Vec3b green(0,255,0), blue (255,0,0);  
  22.   
  23.     for (int i = 0; i < image.rows; ++i){  
  24.         for (int j = 0; j < image.cols; ++j){  
  25.             const Mat sampleMat = (Mat_<float>(1,2) << i,j);  
  26.             Mat response;  
  27.             float result = knn.find_nearest(sampleMat,1);  
  28.             if (result !=0){  
  29.                 image.at<Vec3b>(j, i)  = green;  
  30.             }  
  31.             else    
  32.                 image.at<Vec3b>(j, i)  = blue;  
  33.         }  
  34.     }  
  35.   
  36.         // Show the training data  
  37.         for(int i=0;i<5;i++){  
  38.             circle( image, Point(trainingData[i][0],  trainingData[i][1]),   
  39.                 5, Scalar(  0,   0,   0), -1, 8);  
  40.             circle( image, Point(trainingData[i+5][0],  trainingData[i+5][1]),   
  41.                 5, Scalar(255, 255, 255), -1, 8);  
  42.         }  
  43.         imshow("KNN Simple Example", image); // show it to the user  
  44.         waitKey(10000);  
  45.   
  46. }  

使用的是之前BP神经网络中的例子,分类结果如下:

预测函数find_nearest()除了输入sample参数外还有些其他的参数:
[cpp] view plaincopy在CODE上查看代码片派生到我的代码片
  1. float CvKNearest::find_nearest(const Mat& samples, int k, Mat* results=0,   
  2. const float** neighbors=0, Mat* neighborResponses=0, Mat* dist=0 )  


即,samples为样本数*特征数的浮点矩阵;K为寻找最近点的个数;results与预测结果;neibhbors为k*样本数的指针数组(输入为const,实在不知为何如此设计);neighborResponse为样本数*k的每个样本K个近邻的输出值;dist为样本数*k的每个样本K个近邻的距离。

另一个例子

OpenCV refman也提供了一个类似的示例,使用CvMat格式的输入参数:
[cpp] view plaincopy在CODE上查看代码片派生到我的代码片
  1. int main( int argc, char** argv )  
  2. {  
  3.     const int K = 10;  
  4.     int i, j, k, accuracy;  
  5.     float response;  
  6.     int train_sample_count = 100;  
  7.     CvRNG rng_state = cvRNG(-1);  
  8.     CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );  
  9.     CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );  
  10.     IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );  
  11.     float _sample[2];  
  12.     CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );  
  13.     cvZero( img );  
  14.     CvMat trainData1, trainData2, trainClasses1, trainClasses2;  
  15.     // form the training samples  
  16.     cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );  
  17.     cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );  
  18.     cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );  
  19.     cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );  
  20.     cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );  
  21.     cvSet( &trainClasses1, cvScalar(1) );  
  22.     cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );  
  23.     cvSet( &trainClasses2, cvScalar(2) );  
  24.     // learn classifier  
  25.     CvKNearest knn( trainData, trainClasses, 0, false, K );  
  26.     CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);  
  27.     for( i = 0; i < img->height; i++ )  
  28.     {  
  29.         for( j = 0; j < img->width; j++ )  
  30.         {  
  31.             sample.data.fl[0] = (float)j;  
  32.             sample.data.fl[1] = (float)i;  
  33.             // estimate the response and get the neighbors’ labels  
  34.             response = knn.find_nearest(&sample,K,0,0,nearests,0);  
  35.             // compute the number of neighbors representing the majority  
  36.             for( k = 0, accuracy = 0; k < K; k++ )  
  37.             {  
  38.                 if( nearests->data.fl[k] == response)  
  39.                     accuracy++;  
  40.             }  
  41.             // highlight the pixel depending on the accuracy (or confidence)  
  42.             cvSet2D( img, i, j, response == 1 ?  
  43.                 (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :  
  44.                 (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );  
  45.         }  
  46.     }  
  47.     // display the original training samples  
  48.     for( i = 0; i < train_sample_count/2; i++ )  
  49.     {  
  50.         CvPoint pt;  
  51.         pt.x = cvRound(trainData1.data.fl[i*2]);  
  52.         pt.y = cvRound(trainData1.data.fl[i*2+1]);  
  53.         cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );  
  54.         pt.x = cvRound(trainData2.data.fl[i*2]);  
  55.         pt.y = cvRound(trainData2.data.fl[i*2+1]);  
  56.         cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );  
  57.     }  
  58.     cvNamedWindow( "classifier result", 1 );  
  59.     cvShowImage( "classifier result", img );  
  60.     cvWaitKey(0);  
  61.     cvReleaseMat( &trainClasses );  
  62.     cvReleaseMat( &trainData );  
  63.     return 0;  
  64. }  
分类结果:


KNN的思想很好理解,也非常容易实现,同时分类结果较高,对异常值不敏感。但计算复杂度较高,不适于大数据的分类问题。
0 0
原创粉丝点击