KNN 算法理解
来源:互联网 发布:小学英语教学软件 编辑:程序博客网 时间:2024/05/22 03:21
kNN算法又称为k近邻分类(k-nearest neighbor classification)算法。
一、基本思想:
kNN算法的指导思想是“近朱者赤,近墨者黑”,由你的邻居来推断出你的类别。在距离空间里,如果一个样本的最接近的k个邻居里,绝大多数属于某个类别,则该样本也属于这个类别。俗话叫,“随大流”。
代表论文:Discriminant Adaptive Nearest Neighbor Classification
二、算法描述:
1、算法步骤:
step.1---初始化距离为最大值
step.2---计算未知样本和每个训练样本的距离dist
step.3---得到目前K个最临近样本中的最大距离maxdist
step.4---如果dist小于maxdist,则将该训练样本作为K-最近邻样本
step.5---重复步骤2、3、4,直到未知样本和所有训练样本的距离都算完
step.6---统计K-最近邻样本中每个类标号出现的次数
step.7---选择出现频率最大的类标号作为未知样本的类标号
2、K的选取:
如何选择一个最佳的K值取决于数据。一般情况下,在分类时较大的K值能够减小噪声的影响。但会使类别之间的界限变得模糊。比如下图:
待测样本(绿色圆圈)既可能分到红色三角形类,也可能分到蓝色正方形类。如果k取3,从图可见,待测样本的3个邻居在实线的内圆里,按多数投票结果,它属于红色三角形类,票数1:2.但是如果k取5,那么待测样本的最邻近的5个样本在虚线的圆里,按表决法,它又属于蓝色正方形类,票数2(红色三角形):3(蓝色正方形)。另外还有认为,经验规则,k一般低于训练样本数的平方根。
三、优缺点
1、优点
简单,易于理解,易于实现,无需估计参数,无需训练。适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型)。特别适合于多分类问题(multi-modal,对象具有多个类别标签),例如根据基因特征来判断其功能分类,kNN比SVM的表现要好。
2、缺点
懒惰算法,对测试样本分类时的计算量大,内存开销大,评分慢可解释性较差,无法给出决策树那样的规则。
四、行业应用
客户流失预测、欺诈侦 测等(更适合于稀有事件的分类问题)。
五、性能问题
kNN是一种懒惰算法,平时不好好学习,考试(对测试样本分类)时才临阵磨枪(临时去找k个近邻)。懒惰的后果:构造模型很简单,但在对测试样本分类地的系统开销大,因为要扫描全部训练样本并计算距离。已经有一些方法提高计算的效率,例如压缩训练样本量等。
六:测试代码和数据集:
数据集:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
// KNN.cpp K-最近邻分类算法//////////////////////////////////////////////////////////////////////////////////////////////////////////#include <stdlib.h>#include <stdio.h>#include <memory.h>#include <string.h>#include <iostream>#include <math.h>#include <fstream>using namespace std;//////////////////////////////////////////////////////////////////////////////////////////////////////////// 宏定义//////////////////////////////////////////////////////////////////////////////////////////////////////////#define ATTR_NUM 4 //属性数目#define MAX_SIZE_OF_TRAINING_SET 1000 //训练数据集的最大大小#define MAX_SIZE_OF_TEST_SET 100 //测试数据集的最大大小#define MAX_VALUE 10000.0 //属性最大值#define K 7//结构体struct dataVector { int ID; //ID号 char classLabel[15]; //分类标号 double attributes[ATTR_NUM]; //属性 };struct distanceStruct { int ID; //ID号 double distance; //距离 char classLabel[15]; //分类标号};//////////////////////////////////////////////////////////////////////////////////////////////////////////// 全局变量//////////////////////////////////////////////////////////////////////////////////////////////////////////struct dataVector gTrainingSet[MAX_SIZE_OF_TRAINING_SET]; //训练数据集struct dataVector gTestSet[MAX_SIZE_OF_TEST_SET]; //测试数据集struct distanceStruct gNearestDistance[K]; //K个最近邻距离int curTrainingSetSize=0; //训练数据集的大小int curTestSetSize=0; //测试数据集的大小//////////////////////////////////////////////////////////////////////////////////////////////////////////// 求 vector1=(x1,x2,...,xn)和vector2=(y1,y2,...,yn)的欧几里德距离//////////////////////////////////////////////////////////////////////////////////////////////////////////double Distance(struct dataVector vector1,struct dataVector vector2){ double dist,sum=0.0; for(int i=0;i<ATTR_NUM;i++) { sum+=(vector1.attributes[i]-vector2.attributes[i])*(vector1.attributes[i]-vector2.attributes[i]); } dist=sqrt(sum); return dist;}//////////////////////////////////////////////////////////////////////////////////////////////////////////// 得到gNearestDistance中的最大距离,返回下标//////////////////////////////////////////////////////////////////////////////////////////////////////////int GetMaxDistance(){ int maxNo=0; for(int i=1;i<K;i++) { if(gNearestDistance[i].distance>gNearestDistance[maxNo].distance) maxNo = i; } return maxNo;}//////////////////////////////////////////////////////////////////////////////////////////////////////////// 对未知样本Sample分类//////////////////////////////////////////////////////////////////////////////////////////////////////////char* Classify(struct dataVector Sample){ double dist=0; int maxid=0,freq[K],i,tmpfreq=1;; char *curClassLable=gNearestDistance[0].classLabel; memset(freq,1,sizeof(freq)); //step.1---初始化距离为最大值 for(i=0;i<K;i++) { gNearestDistance[i].distance=MAX_VALUE; } //step.2---计算K-最近邻距离 for(i=0;i<curTrainingSetSize;i++) { //step.2.1---计算未知样本和每个训练样本的距离 dist=Distance(gTrainingSet[i],Sample); //step.2.2---得到gNearestDistance中的最大距离 maxid=GetMaxDistance(); //step.2.3---如果距离小于gNearestDistance中的最大距离,则将该样本作为K-最近邻样本 if(dist<gNearestDistance[maxid].distance) { gNearestDistance[maxid].ID=gTrainingSet[i].ID; gNearestDistance[maxid].distance=dist; strcpy(gNearestDistance[maxid].classLabel,gTrainingSet[i].classLabel); } } //step.3---统计每个类出现的次数 for(i=0;i<K;i++) { for(int j=0;j<K;j++) { if((i!=j)&&(strcmp(gNearestDistance[i].classLabel,gNearestDistance[j].classLabel)==0)) { freq[i]+=1; } } } //step.4---选择出现频率最大的类标号 for(i=0;i<K;i++) { if(freq[i]>tmpfreq) { tmpfreq=freq[i]; curClassLable=gNearestDistance[i].classLabel; } } return curClassLable;}//////////////////////////////////////////////////////////////////////////////////////////////////////////// 主函数//////////////////////////////////////////////////////////////////////////////////////////////////////////void main(){ char c; char *classLabel=""; int i,j, rowNo=0,TruePositive=0,FalsePositive=0; ifstream filein("iris.data"); FILE *fp; if(filein.fail()) { cout<<"Can't open data.txt"<<endl; return; } //step.1---读文件 while(!filein.eof()) { rowNo++;//第一组数据rowNo=1 if(curTrainingSetSize>=MAX_SIZE_OF_TRAINING_SET) { cout<<"The training set has "<<MAX_SIZE_OF_TRAINING_SET<<" examples!"<<endl<<endl; break ; } //rowNo%3!=0的100组数据作为训练数据集 if(rowNo%3!=0) { gTrainingSet[curTrainingSetSize].ID=rowNo; for(int i = 0;i < ATTR_NUM;i++) { filein>>gTrainingSet[curTrainingSetSize].attributes[i]; filein>>c; } filein>>gTrainingSet[curTrainingSetSize].classLabel; curTrainingSetSize++; } //剩下rowNo%3==0的50组做测试数据集 else if(rowNo%3==0) { gTestSet[curTestSetSize].ID=rowNo; for(int i = 0;i < ATTR_NUM;i++) { filein>>gTestSet[curTestSetSize].attributes[i]; filein>>c; } filein>>gTestSet[curTestSetSize].classLabel; curTestSetSize++; } } filein.close(); //step.2---KNN算法进行分类,并将结果写到文件iris_OutPut.txt fp=fopen("iris_OutPut.txt","w+t"); //用KNN算法进行分类 fprintf(fp,"************************************程序说明***************************************\n"); fprintf(fp,"** 采用KNN算法对iris.data分类。为了操作方便,对各组数据添加rowNo属性,第一组rowNo=1!\n"); fprintf(fp,"** 共有150组数据,选择rowNo模3不等于0的100组作为训练数据集,剩下的50组做测试数据集\n"); fprintf(fp,"***********************************************************************************\n\n"); fprintf(fp,"************************************实验结果***************************************\n\n"); for(i=0;i<curTestSetSize;i++) { fprintf(fp,"************************************第%d组数据**************************************\n",i+1); classLabel =Classify(gTestSet[i]); if(strcmp(classLabel,gTestSet[i].classLabel)==0)//相等时,分类正确 { TruePositive++; } cout<<"rowNo: "; cout<<gTestSet[i].ID<<" \t"; cout<<"KNN分类结果: "; cout<<classLabel<<"(正确类标号: "; cout<<gTestSet[i].classLabel<<")\n"; fprintf(fp,"rowNo: %3d \t KNN分类结果: %s ( 正确类标号: %s )\n",gTestSet[i].ID,classLabel,gTestSet[i].classLabel); if(strcmp(classLabel,gTestSet[i].classLabel)!=0)//不等时,分类错误 { // cout<<" ***分类错误***\n"; fprintf(fp," ***分类错误***\n"); } fprintf(fp,"%d-最临近数据:\n",K); for(j=0;j<K;j++) { // cout<<gNearestDistance[j].ID<<"\t"<<gNearestDistance[j].distance<<"\t"<<gNearestDistance[j].classLabel[15]<<endl; fprintf(fp,"rowNo: %3d \t Distance: %f \tClassLable: %s\n",gNearestDistance[j].ID,gNearestDistance[j].distance,gNearestDistance[j].classLabel); } fprintf(fp,"\n"); } FalsePositive=curTestSetSize-TruePositive; fprintf(fp,"***********************************结果分析**************************************\n",i); fprintf(fp,"TP(True positive): %d\nFP(False positive): %d\naccuracy: %f\n",TruePositive,FalsePositive,double(TruePositive)/(curTestSetSize-1)); fclose(fp); return;}
- KNN算法理解
- KNN算法理解
- KNN算法理解
- KNN算法理解
- 深入理解KNN算法
- KNN算法理解
- KNN 算法理解
- KNN算法理解
- KNN算法理解
- KNN算法理解
- KNN算法理解
- KNN算法理解
- KNN算法理解
- KNN算法的个人理解
- KNN算法理解和应用
- opencv2.4.9中KNN算法理解
- 十大算法之-------Knn理解
- 机器学习(1)-KNN算法理解
- 实验一 彩色空间转换
- 图解hive运行机制
- 手把手教你写专利申请书/如何申请专利
- 我所理解的Handler的使用及其原理浅析
- ElasticSearch 5.0.0 安装部署常见错误或问题
- KNN 算法理解
- codevs1247排排站
- 服务器-Web框架配置
- MIT的《深度学习》精读(1)
- hadoop集群出现两个datanode节点互相排斥的情况解决
- Flume-ng源码解析之Channel组件
- Windows7以管理员身份运行程序
- 谈EXPORT_SYMBOL使用
- Sticks