KNN算法详解

来源:互联网 发布:java模板引擎 编辑:程序博客网 时间:2024/05/22 15:40

K-Nearest Neighbor Classifiction


1. KNN算法是怎么来的?


   猜猜看:最后一行未知电影属于什么类型的电影?


电影名称

打斗次数

接吻次数

电影类型

Califomia Man

3

104

Romance

He's Not Really into Dudes

2

100

Romance

Beautiful Woman

1

81

Romance

Kevin Longblade

101

10

Action

Robo Slayer 3000

99

5

Action

Amped II

98

5

Action

Unkown

18

90

Unkown



如果我们把每部电影当作是平面上的一个点,打斗次数表示X坐标,接吻次数表示Y坐标,那么可以得到下面的点


X坐标

Y坐标

点类型

A点

3

104

Romance

B点

2

100

Romance

C点

1

81

Romance

D点

101

10

Action

E点

99

5

Action

F点

98

5

Action

G

18

90

Unkown

   


再看另一个例子,想一想:下面图片中只有三种豆,有三个豆是未知的种类,如何判定他们的种类?


   

提供一种思路,即:未知的豆离哪种豆最近就认为未知豆和该豆是同一种类。由此,我们引出最近邻算法的定义:为了判定未知样本的类别,以全部训练样本作为挖个好看哦,计算未知样本与所有训练样本的距离,并以最近邻者的类别作为决策未知样本类别的唯一依据。但是,最近邻算法明显是存在缺陷的,我们来看下面这个例子。

   

问题:有一个未知形状X(图中绿色的圆点),如何判断X是什么形状?


   如果采用最近邻算法,我们容易认为该图形为正方形。然而,在离该点稍远处有较多的三角形。或许,该未知点被认为是三角形更为合理。

   显然,通过上面的例子我们可以明显发现最近邻算法的缺陷——对噪声数据过于敏感,为了解决这个问题,我们可以把位置样本周边的最多个最近样本计算在内,扩大参与决策的样本量,以避免个别数据直接决定决策结果。由此,我们引进K-最近邻算法。

2. KNN算法的实现步骤

step.1 -- 初始化距离为最大值;

step.2 -- 计算未知样本和每个训练样本的距离dist;

step.3 -- 得到目前K个最临近样本中的最大距离maxdist;

step.4 -- 如果dist小于maxdist,则将该训练样本作为K-最近邻样本;

step.5 -- 重复步骤2、3、4,直到未知样本和所有训练样本的距离都算完;

step.6 -- 统计K个最近邻样本中每个类别出现的次数;

step.7 -- 选择出现频率最大的类别作为未知样本的类别。

3. KNN算法的缺陷

   观察下面的例子,我们看到,对于位置样本X,通过KNN算法,我们显然可以得到X应属于红点,但对于位置样本Y,通过KNN算法我们似乎得到了Y应属于蓝点的结论,而这个结论直观来看并没有说服力。


   由上面的例子可见:该算法在分类时有个重要的不中是,当样本不平衡时,即:一个类的样本容量很大,而其他类样本数量很小时,很可能导致当输入一个未知样本时,该样本的K个邻居中大数量类的样本占多数。但是这类样本并不接近目标样本,而数量小的这类样本很靠近目标样本。这个时候,我们有理由认为该位置样本属于数量小的样本所属的一类,但是,KNN却不关心这个问题,它只关心哪类样本的数量最多,而不去把距离远近考虑在内,因此,我们可以采用权值的方法来改进。

   和该样本距离小的邻居权值大,和该样本距离大的邻居权值相对较小,由此,将距离远近的因素也考虑在内,避免一个样本过大导致误判的情况。

   此外,从算法实现的过程中,该算法还存在两个严重的问题,第一个是需要存储全部的训练样本;第二个是需要进行繁复的距离计算。

 

5. KNN算法的MATLAB实现

<span style="font-size:18px;">clear all;close all;clc; % 生成样本类1mu1 = [0,0];sigma1 = [0.8,0;0,0.6];data1 = mvnrnd(mu1,sigma1,200);label1 = ones(200,1);plot(data1(:,1),data1(:,2),'o');hold on; % 生成样本类2mu2 = [2.2,1.9];sigma2 = [1.3,0;0,1.1];data2 = mvnrnd(mu2,sigma2,200);label2 = label1+1;plot(data2(:,1),data2(:,2),'r+');hold on; % 样本和K值data=[data1;data2];label=[label1;label2];K=10; % 测试for ii=-3:0.1:6       forjj = -3 : 0.1 : 6              test_data= [ii jj];              label= [label1; label2];              distance= zeros(400,1);              %  计算未知点与样本点的距离              fori = 1:400                     distance(i)= sqrt((test_data(1)-data(i,1)).^2+(test_data(2)-data(i,2)).^2);              end                   %  排序              fori = 1:400                     ma= distance(i);                     forj = i+1:400                            ifdistance(j)<ma                                   ma= distance(j);                                   label_ma= label(j);                                   tmp= j;                            end                     end                                         distance(tmp)= distance(i);                     distance(i)= ma;                     label(tmp)= label(i);                     label(i)= label_ma;              end                   %  统计最近K个样本中类1的个数              num1= 0;              fori = 1:K                     iflabel(i) == 1                            num1= num1 + 1;                     end              end                           num2= K - num1;                           ifnum1>num2                  plot(ii,jj,'r*');              else                     plot(ii,jj);              end             endend </span>



0 0