机器学习-KNN

来源:互联网 发布:中国软件与即时 编辑:程序博客网 时间:2024/05/01 20:34

本文转载自http://blog.csdn.net/lavorange/article/details/16924705

本文不对KNN算法做过多的理论上的解释,主要是针对问题,进行算法的设计和代码的注解。

KNN算法:

优点:精度高、对异常值不敏感、无数据输入假定。

缺点:计算复杂度高、空间复杂度高。

适用数据范围:数值型和标称性。

工作原理:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据及中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k选择不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

K-近邻算法的一般流程:

(1)收集数据:可以使用任何方法

(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式

(3)分析数据:可以使用任何方法

(4)训练算法:此步骤不适用k-邻近算法

(5)测试算法:计算错误率

(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。


问题一:现在我们假设一个场景,就是要为坐标上的点进行分类,如下图所示:



上图一共12个左边点,每个坐标点都有相应的坐标(x,y)以及它所属的类别A/B,那么现在需要做的就是给定一个点坐标(x1,y1),判断它属于的类别A或者B。

所有的坐标点在data.txt文件中:

[cpp] view plaincopy
  1. 0.0 1.1 A  
  2. 1.0 1.0 A  
  3. 2.0 1.0 B  
  4. 0.5 0.5 A  
  5. 2.5 0.5 B  
  6. 0.0 0.0 A  
  7. 1.0 0.0 A   
  8. 2.0 0.0 B  
  9. 3.0 0.0 B  
  10. 0.0 -1.0 A  
  11. 1.0 -1.0 A  
  12. 2.0 -1.0 B  


step1:通过类的默认构造函数去初始化训练数据集dataSet和测试数据testData。

step2:用get_distance()来计算测试数据testData和每一个训练数据dataSet[index]的距离,用map_index_dis来保存键值对<index,distance>,其中index代表第几个训练数据,distance代表第index个训练数据和测试数据的距离。

step3:将map_index_dis按照value值(即distance值)从小到大的顺序排序,然后取前k个最小的value值,用map_label_freq来记录每一个类标签出现的频率。

step4:遍历map_label_freq中的value值,返回value最大的那个key值,就是测试数据属于的类。


看一下代码KNN_0.cc:

[cpp] view plaincopy
  1. #include<iostream>  
  2. #include<map>  
  3. #include<vector>  
  4. #include<stdio.h>  
  5. #include<cmath>  
  6. #include<cstdlib>  
  7. #include<algorithm>  
  8. #include<fstream>  
  9.   
  10. using namespace std;  
  11.   
  12. typedef char tLabel;  
  13. typedef double tData;  
  14. typedef pair<int,double>  PAIR;  
  15. const int colLen = 2;  
  16. const int rowLen = 12;  
  17. ifstream fin;  
  18. ofstream fout;  
  19.   
  20. class KNN  
  21. {  
  22. private:  
  23.         tData dataSet[rowLen][colLen];  
  24.         tLabel labels[rowLen];  
  25.         tData testData[colLen];  
  26.         int k;  
  27.         map<int,double> map_index_dis;  
  28.         map<tLabel,int> map_label_freq;  
  29.         double get_distance(tData *d1,tData *d2);  
  30. public:  
  31.   
  32.         KNN(int k);  
  33.   
  34.         void get_all_distance();  
  35.           
  36.         void get_max_freq_label();  
  37.   
  38.         struct CmpByValue  
  39.         {  
  40.             bool operator() (const PAIR& lhs,const PAIR& rhs)  
  41.             {  
  42.                 return lhs.second < rhs.second;  
  43.             }  
  44.         };  
  45.           
  46. };  
  47.   
  48. KNN::KNN(int k)  
  49. {  
  50.     this->k = k;  
  51.   
  52.     fin.open("data.txt");  
  53.   
  54.     if(!fin)  
  55.     {  
  56.         cout<<"can not open the file data.txt"<<endl;  
  57.         exit(1);  
  58.     }  
  59.   
  60.     /* input the dataSet */   
  61.     for(int i=0;i<rowLen;i++)  
  62.     {  
  63.         for(int j=0;j<colLen;j++)  
  64.         {  
  65.             fin>>dataSet[i][j];  
  66.         }  
  67.         fin>>labels[i];  
  68.     }  
  69.   
  70.     cout<<"please input the test data :"<<endl;  
  71.     /* inuput the test data */  
  72.     for(int i=0;i<colLen;i++)  
  73.         cin>>testData[i];  
  74.       
  75. }  
  76.   
  77. /* 
  78.  * calculate the distance between test data and dataSet[i] 
  79.  */  
  80. double KNN:: get_distance(tData *d1,tData *d2)  
  81. {  
  82.     double sum = 0;  
  83.     for(int i=0;i<colLen;i++)  
  84.     {  
  85.         sum += pow( (d1[i]-d2[i]) , 2 );  
  86.     }  
  87.   
  88. //  cout<<"the sum is = "<<sum<<endl;  
  89.     return sqrt(sum);  
  90. }  
  91.   
  92. /* 
  93.  * calculate all the distance between test data and each training data 
  94.  */  
  95. void KNN:: get_all_distance()  
  96. {  
  97.     double distance;  
  98.     int i;  
  99.     for(i=0;i<rowLen;i++)  
  100.     {  
  101.         distance = get_distance(dataSet[i],testData);  
  102.         //<key,value> => <i,distance>  
  103.         map_index_dis[i] = distance;  
  104.     }  
  105.   
  106.     //traverse the map to print the index and distance  
  107.     map<int,double>::const_iterator it = map_index_dis.begin();  
  108.     while(it!=map_index_dis.end())  
  109.     {  
  110.         cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;  
  111.         it++;  
  112.     }  
  113. }  
  114.   
  115. /* 
  116.  * check which label the test data belongs to to classify the test data  
  117.  */  
  118. void KNN:: get_max_freq_label()  
  119. {  
  120.     //transform the map_index_dis to vec_index_dis  
  121.     vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );  
  122.     //sort the vec_index_dis by distance from low to high to get the nearest data  
  123.     sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());  
  124.   
  125.     for(int i=0;i<k;i++)  
  126.     {  
  127.         cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label = "<<labels[vec_index_dis[i].first]<<" the coordinate ( "<<dataSet[ vec_index_dis[i].first ][0]<<","<<dataSet[ vec_index_dis[i].first ][1]<<" )"<<endl;  
  128.         //calculate the count of each label  
  129.         map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;  
  130.     }  
  131.   
  132.     map<tLabel,int>::const_iterator map_it = map_label_freq.begin();  
  133.     tLabel label;  
  134.     int max_freq = 0;  
  135.     //find the most frequent label  
  136.     while( map_it != map_label_freq.end() )  
  137.     {  
  138.         if( map_it->second > max_freq )  
  139.         {  
  140.             max_freq = map_it->second;  
  141.             label = map_it->first;  
  142.         }  
  143.         map_it++;  
  144.     }  
  145.     cout<<"The test data belongs to the "<<label<<" label"<<endl;  
  146. }  
  147.   
  148. int main()  
  149. {  
  150.     int k ;  
  151.     cout<<"please input the k value : "<<endl;  
  152.     cin>>k;  
  153.     KNN knn(k);  
  154.     knn.get_all_distance();  
  155.     knn.get_max_freq_label();  
  156.     system("pause");   
  157.     return 0;  
  158. }  


我们来测试一下这个分类器(k=5):

testData(5.0,5.0):



testData(-5.0,-5.0):



testData(1.6,0.5):



分类结果的正确性可以通过坐标系来判断,可以看出结果都是正确的。


问题二:使用k-近邻算法改进约会网站的匹配效果

情景如下:我的朋友海伦一直使用在线约会网站寻找合适自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:

>不喜欢的人

>魅力一般的人

>极具魅力的人

尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的分类。她觉得可以在周一到周五约会哪些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好的帮助她将匹配对象划分到确切的分类中。此外海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的归类。

海伦已经收集数据一段时间。她把这些数据存放在文本文件datingTestSet.txt(文件链接:http://yunpan.cn/QUL6SxtiJFPfN)中,每个样本占据一行,总共有1000行。海伦的样本主要包含3中特征:

>每年获得的飞行常客里程数

>玩视频游戏所耗时间的百分比

>每周消费的冰淇淋公升数


数据预处理:归一化数据

我们可以看到,每年获取的飞行常客里程数对于计算结果的影响将远大于其他两个特征。而产生这种现象的唯一原因,仅仅是因为飞行常客书远大于其他特征值。但是这三种特征是同等重要的,因此作为三个等权重的特征之一,飞行常客数不应该如此严重地影响到计算结果。

处理这种不同取值范围的特征值时,我们通常采用的方法是数值归一化,如将取值范围处理为0到1或者-1到1之间。

公式为:newValue = (oldValue - min) / (max - min)

其中min和max分别是数据集中的最小特征值和最大特征值。我们增加一个auto_norm_data函数来归一化数据。

同事还要设计一个get_error_rate来计算分类的错误率,选总体数据的10%作为测试数据,90%作为训练数据,当然也可以自己设定百分比。

其他的算法设计都与问题一类似。


代码如下KNN_2.cc(k=7):

[cpp] view plaincopy
  1. /* add the get_error_rate function */  
  2.   
  3. #include<iostream>  
  4. #include<map>  
  5. #include<vector>  
  6. #include<stdio.h>  
  7. #include<cmath>  
  8. #include<cstdlib>  
  9. #include<algorithm>  
  10. #include<fstream>  
  11.   
  12. using namespace std;  
  13.   
  14. typedef string tLabel;  
  15. typedef double tData;  
  16. typedef pair<int,double>  PAIR;  
  17. const int MaxColLen = 10;  
  18. const int MaxRowLen = 10000;  
  19. ifstream fin;  
  20. ofstream fout;  
  21.   
  22. class KNN  
  23. {  
  24. private:  
  25.         tData dataSet[MaxRowLen][MaxColLen];  
  26.         tLabel labels[MaxRowLen];  
  27.         tData testData[MaxColLen];  
  28.         int rowLen;  
  29.         int colLen;  
  30.         int k;  
  31.         int test_data_num;  
  32.         map<int,double> map_index_dis;  
  33.         map<tLabel,int> map_label_freq;  
  34.         double get_distance(tData *d1,tData *d2);  
  35. public:  
  36.         KNN(int k , int rowLen , int colLen , char *filename);  
  37.         void get_all_distance();  
  38.         tLabel get_max_freq_label();  
  39.         void auto_norm_data();  
  40.         void get_error_rate();  
  41.         struct CmpByValue  
  42.         {  
  43.             bool operator() (const PAIR& lhs,const PAIR& rhs)  
  44.             {  
  45.                 return lhs.second < rhs.second;  
  46.             }  
  47.         };  
  48.   
  49.         ~KNN();   
  50. };  
  51.   
  52. KNN::~KNN()  
  53. {  
  54.     fin.close();  
  55.     fout.close();  
  56.     map_index_dis.clear();  
  57.     map_label_freq.clear();  
  58. }  
  59.   
  60. KNN::KNN(int k , int row ,int col , char *filename)  
  61. {  
  62.     this->rowLen = row;  
  63.     this->colLen = col;  
  64.     this->k = k;  
  65.     test_data_num = 0;  
  66.       
  67.     fin.open(filename);  
  68.     fout.open("result.txt");  
  69.   
  70.     if( !fin || !fout )  
  71.     {  
  72.         cout<<"can not open the file"<<endl;  
  73.         exit(0);  
  74.     }  
  75.   
  76.     for(int i=0;i<rowLen;i++)  
  77.     {  
  78.         for(int j=0;j<colLen;j++)  
  79.         {  
  80.             fin>>dataSet[i][j];  
  81.             fout<<dataSet[i][j]<<" ";  
  82.         }  
  83.         fin>>labels[i];  
  84.         fout<<labels[i]<<endl;  
  85.     }  
  86.   
  87. }  
  88.   
  89. void KNN:: get_error_rate()  
  90. {  
  91.     int i,j,count = 0;  
  92.     tLabel label;  
  93.     cout<<"please input the number of test data : "<<endl;  
  94.     cin>>test_data_num;  
  95.     for(i=0;i<test_data_num;i++)  
  96.     {  
  97.         for(j=0;j<colLen;j++)  
  98.         {  
  99.             testData[j] = dataSet[i][j];          
  100.         }  
  101.           
  102.         get_all_distance();  
  103.         label = get_max_freq_label();  
  104.         if( label!=labels[i] )  
  105.             count++;  
  106.         map_index_dis.clear();  
  107.         map_label_freq.clear();  
  108.     }  
  109.     cout<<"the error rate is = "<<(double)count/(double)test_data_num<<endl;  
  110. }  
  111.   
  112. double KNN:: get_distance(tData *d1,tData *d2)  
  113. {  
  114.     double sum = 0;  
  115.     for(int i=0;i<colLen;i++)  
  116.     {  
  117.         sum += pow( (d1[i]-d2[i]) , 2 );  
  118.     }  
  119.   
  120.     //cout<<"the sum is = "<<sum<<endl;  
  121.     return sqrt(sum);  
  122. }  
  123.   
  124. void KNN:: get_all_distance()  
  125. {  
  126.     double distance;  
  127.     int i;  
  128.     for(i=test_data_num;i<rowLen;i++)  
  129.     {  
  130.         distance = get_distance(dataSet[i],testData);  
  131.         map_index_dis[i] = distance;  
  132.     }  
  133.   
  134. //  map<int,double>::const_iterator it = map_index_dis.begin();  
  135. //  while(it!=map_index_dis.end())  
  136. //  {  
  137. //      cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;  
  138. //      it++;  
  139. //  }  
  140.   
  141. }  
  142.   
  143. tLabel KNN:: get_max_freq_label()  
  144. {  
  145.     vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );  
  146.     sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());  
  147.   
  148.     for(int i=0;i<k;i++)  
  149.     {  
  150.         cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label = "<<labels[ vec_index_dis[i].first ]<<" the coordinate ( ";  
  151.         int j;  
  152.         for(j=0;j<colLen-1;j++)  
  153.         {  
  154.             cout<<dataSet[ vec_index_dis[i].first ][j]<<",";  
  155.         }  
  156.         cout<<dataSet[ vec_index_dis[i].first ][j]<<" )"<<endl;  
  157.   
  158.         map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;  
  159.     }  
  160.   
  161.     map<tLabel,int>::const_iterator map_it = map_label_freq.begin();  
  162.     tLabel label;  
  163.     int max_freq = 0;  
  164.     while( map_it != map_label_freq.end() )  
  165.     {  
  166.         if( map_it->second > max_freq )  
  167.         {  
  168.             max_freq = map_it->second;  
  169.             label = map_it->first;  
  170.         }  
  171.         map_it++;  
  172.     }  
  173.     cout<<"The test data belongs to the "<<label<<" label"<<endl;  
  174.     return label;  
  175. }  
  176.   
  177. void KNN::auto_norm_data()  
  178. {  
  179.     tData maxa[colLen] ;  
  180.     tData mina[colLen] ;  
  181.     tData range[colLen] ;  
  182.     int i,j;  
  183.   
  184.     for(i=0;i<colLen;i++)  
  185.     {  
  186.         maxa[i] = max(dataSet[0][i],dataSet[1][i]);  
  187.         mina[i] = min(dataSet[0][i],dataSet[1][i]);  
  188.     }  
  189.   
  190.     for(i=2;i<rowLen;i++)  
  191.     {  
  192.         for(j=0;j<colLen;j++)  
  193.         {  
  194.             if( dataSet[i][j]>maxa[j] )  
  195.             {  
  196.                 maxa[j] = dataSet[i][j];  
  197.             }  
  198.             else if( dataSet[i][j]<mina[j] )  
  199.             {  
  200.                 mina[j] = dataSet[i][j];  
  201.             }  
  202.         }  
  203.     }  
  204.   
  205.     for(i=0;i<colLen;i++)  
  206.     {  
  207.         range[i] = maxa[i] - mina[i] ;   
  208.         //normalize the test data set  
  209.         testData[i] = ( testData[i] - mina[i] )/range[i] ;  
  210.     }  
  211.   
  212.     //normalize the training data set  
  213.     for(i=0;i<rowLen;i++)  
  214.     {  
  215.         for(j=0;j<colLen;j++)  
  216.         {  
  217.             dataSet[i][j] = ( dataSet[i][j] - mina[j] )/range[j];  
  218.         }  
  219.     }  
  220. }  
  221.   
  222. int main(int argc , char** argv)  
  223. {  
  224.     int k,row,col;  
  225.     char *filename;  
  226.       
  227.     if( argc!=5 )  
  228.     {  
  229.         cout<<"The input should be like this : ./a.out k row col filename"<<endl;  
  230.         exit(1);  
  231.     }  
  232.   
  233.     k = atoi(argv[1]);  
  234.     row = atoi(argv[2]);  
  235.     col = atoi(argv[3]);  
  236.     filename = argv[4];  
  237.   
  238.     KNN knn(k,row,col,filename);  
  239.   
  240.     knn.auto_norm_data();  
  241.     knn.get_error_rate();  
  242. //  knn.get_all_distance();  
  243. //  knn.get_max_freq_label();  
  244.       
  245.     return 0;  
  246. }  

makefile:

[cpp] view plaincopy
  1. target:  
  2.     g++ KNN_2.cc  
  3.         ./a.out 7 1000 3 datingTestSet.txt  


结果:

可以看到:在测试数据为10%和训练数据90%的比例下,可以看到错误率为4%,相对来讲还是很准确的。


构建完整可用系统:

已经通过使用数据对分类器进行了测试,现在可以使用分类器为海伦来对人进行分类。

代码KNN_1.cc(k=7):

[cpp] view plaincopy
  1. /* add the auto_norm_data */  
  2.   
  3. #include<iostream>  
  4. #include<map>  
  5. #include<vector>  
  6. #include<stdio.h>  
  7. #include<cmath>  
  8. #include<cstdlib>  
  9. #include<algorithm>  
  10. #include<fstream>  
  11.   
  12. using namespace std;  
  13.   
  14. typedef string tLabel;  
  15. typedef double tData;  
  16. typedef pair<int,double>  PAIR;  
  17. const int MaxColLen = 10;  
  18. const int MaxRowLen = 10000;  
  19. ifstream fin;  
  20. ofstream fout;  
  21.   
  22. class KNN  
  23. {  
  24. private:  
  25.         tData dataSet[MaxRowLen][MaxColLen];  
  26.         tLabel labels[MaxRowLen];  
  27.         tData testData[MaxColLen];  
  28.         int rowLen;  
  29.         int colLen;  
  30.         int k;  
  31.         map<int,double> map_index_dis;  
  32.         map<tLabel,int> map_label_freq;  
  33.         double get_distance(tData *d1,tData *d2);  
  34. public:  
  35.         KNN(int k , int rowLen , int colLen , char *filename);  
  36.         void get_all_distance();  
  37.         tLabel get_max_freq_label();  
  38.         void auto_norm_data();  
  39.         struct CmpByValue  
  40.         {  
  41.             bool operator() (const PAIR& lhs,const PAIR& rhs)  
  42.             {  
  43.                 return lhs.second < rhs.second;  
  44.             }  
  45.         };  
  46.   
  47.         ~KNN();   
  48. };  
  49.   
  50. KNN::~KNN()  
  51. {  
  52.     fin.close();  
  53.     fout.close();  
  54.     map_index_dis.clear();  
  55.     map_label_freq.clear();  
  56. }  
  57.   
  58. KNN::KNN(int k , int row ,int col , char *filename)  
  59. {  
  60.     this->rowLen = row;  
  61.     this->colLen = col;  
  62.     this->k = k;  
  63.       
  64.     fin.open(filename);  
  65.     fout.open("result.txt");  
  66.   
  67.     if( !fin || !fout )  
  68.     {  
  69.         cout<<"can not open the file"<<endl;  
  70.         exit(0);  
  71.     }  
  72.   
  73.     //input the training data set  
  74.     for(int i=0;i<rowLen;i++)  
  75.     {  
  76.         for(int j=0;j<colLen;j++)  
  77.         {  
  78.             fin>>dataSet[i][j];  
  79.             fout<<dataSet[i][j]<<" ";  
  80.         }  
  81.         fin>>labels[i];  
  82.         fout<<labels[i]<<endl;  
  83.     }  
  84.   
  85.     //input the test data  
  86.     cout<<"frequent flier miles earned per year?";  
  87.     cin>>testData[0];  
  88.     cout<<"percentage of time spent playing video games?";  
  89.     cin>>testData[1];  
  90.     cout<<"liters of ice cream consumed per year?";  
  91.     cin>>testData[2];  
  92. }  
  93.   
  94. double KNN:: get_distance(tData *d1,tData *d2)  
  95. {  
  96.     double sum = 0;  
  97.     for(int i=0;i<colLen;i++)  
  98.     {  
  99.         sum += pow( (d1[i]-d2[i]) , 2 );  
  100.     }  
  101.   
  102.     return sqrt(sum);  
  103. }  
  104.   
  105. void KNN:: get_all_distance()  
  106. {  
  107.     double distance;  
  108.     int i;  
  109.     for(i=0;i<rowLen;i++)  
  110.     {  
  111.         distance = get_distance(dataSet[i],testData);  
  112.         map_index_dis[i] = distance;  
  113.     }  
  114.   
  115. //  map<int,double>::const_iterator it = map_index_dis.begin();  
  116. //  while(it!=map_index_dis.end())  
  117. //  {  
  118. //      cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;  
  119. //      it++;  
  120. //  }  
  121.   
  122. }  
  123.   
  124. tLabel KNN:: get_max_freq_label()  
  125. {  
  126.     vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );  
  127.     sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());  
  128.   
  129.     for(int i=0;i<k;i++)  
  130.     {  
  131.         /*     
  132.         cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label = "<<labels[ vec_index_dis[i].first ]<<" the coordinate ( "; 
  133.         int j; 
  134.         for(j=0;j<colLen-1;j++) 
  135.         { 
  136.             cout<<dataSet[ vec_index_dis[i].first ][j]<<","; 
  137.         } 
  138.         cout<<dataSet[ vec_index_dis[i].first ][j]<<" )"<<endl; 
  139.         */  
  140.         map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;  
  141.     }  
  142.   
  143.     map<tLabel,int>::const_iterator map_it = map_label_freq.begin();  
  144.     tLabel label;  
  145.     int max_freq = 0;  
  146.     /*traverse the map_label_freq to get the most frequent label*/  
  147.     while( map_it != map_label_freq.end() )  
  148.     {  
  149.         if( map_it->second > max_freq )  
  150.         {  
  151.             max_freq = map_it->second;  
  152.             label = map_it->first;  
  153.         }  
  154.         map_it++;  
  155.     }  
  156.     return label;  
  157. }  
  158.   
  159. /* 
  160.  * normalize the training data set 
  161.  */  
  162. void KNN::auto_norm_data()  
  163. {  
  164.     tData maxa[colLen] ;  
  165.     tData mina[colLen] ;  
  166.     tData range[colLen] ;  
  167.     int i,j;  
  168.   
  169.     for(i=0;i<colLen;i++)  
  170.     {  
  171.         maxa[i] = max(dataSet[0][i],dataSet[1][i]);  
  172.         mina[i] = min(dataSet[0][i],dataSet[1][i]);  
  173.     }  
  174.   
  175.     for(i=2;i<rowLen;i++)  
  176.     {  
  177.         for(j=0;j<colLen;j++)  
  178.         {  
  179.             if( dataSet[i][j]>maxa[j] )  
  180.             {  
  181.                 maxa[j] = dataSet[i][j];  
  182.             }  
  183.             else if( dataSet[i][j]<mina[j] )  
  184.             {  
  185.                 mina[j] = dataSet[i][j];  
  186.             }  
  187.         }  
  188.     }  
  189.   
  190.     for(i=0;i<colLen;i++)  
  191.     {  
  192.         range[i] = maxa[i] - mina[i] ;   
  193.         //normalize the test data set  
  194.         testData[i] = ( testData[i] - mina[i] )/range[i] ;  
  195.     }  
  196.   
  197.     //normalize the training data set  
  198.     for(i=0;i<rowLen;i++)  
  199.     {  
  200.         for(j=0;j<colLen;j++)  
  201.         {  
  202.             dataSet[i][j] = ( dataSet[i][j] - mina[j] )/range[j];  
  203.         }  
  204.     }  
  205. }  
  206.   
  207. int main(int argc , char** argv)  
  208. {  
  209.     int k,row,col;  
  210.     char *filename;  
  211.       
  212.     if( argc!=5 )  
  213.     {  
  214.         cout<<"The input should be like this : ./a.out k row col filename"<<endl;  
  215.         exit(1);  
  216.     }  
  217.   
  218.     k = atoi(argv[1]);  
  219.     row = atoi(argv[2]);  
  220.     col = atoi(argv[3]);  
  221.     filename = argv[4];  
  222.   
  223.     KNN knn(k,row,col,filename);  
  224.   
  225.     knn.auto_norm_data();  
  226.     knn.get_all_distance();  
  227.     cout<<"You will probably like this person : "<<knn.get_max_freq_label()<<endl;  
  228.       
  229.     return 0;  
  230. }  


makefile:

[cpp] view plaincopy
  1. target:  
  2.     g++ KNN_1.cc  
  3.         ./a.out 7 1000 3 datingTestSet.txt  

结果:



KNN_1.cc和KNN_2.cc的差别就在于后者对分类器的性能(即分类错误率)进行分析,而前者直接对具体实际的数据进行了分类。
0 0
原创粉丝点击