k近邻分类(kNN)

来源:互联网 发布:陀螺 知乎 编辑:程序博客网 时间:2024/05/21 19:15

k近邻分类(kNN)

一、KNN原理

        kNN属于监督分类方法,原理是利用某种距离度量方式来计算未知数据与已知数据的距离,并根据距离来确定数据光谱间的相似性,选取最近的k个距离作为判定未知数据类别的依据。在分类时,kNN常用方法有:投票法,根据k个距离对应已知数据的类别进行统计,把出现次数最多的类别作为未知数据的类别;平均法,利用k个距离及其对应已知数据的类别,计算k个距离中不同类别下的平均距离,把未知数据划分到平均距离最小的类别中去;加权法,根据k个距离对应已知数据的类别,再由这k个距离的远近设置权值,距离越远权值越大,反之越小,然后计算k个距离在各类别中的加权和,最后把最小加权和对应的类别作为未知数据的类别。
       根据kNN算法的原理可知,k值的选取对分类结果影响较大,图2.7为不同k值的分类结果,图中采用投票法来辨别未知数据的类别。从图2.7可以看出不同的k得到的类别是不同的,使用kNN分类时,k值的确定将是一个难点。


二、KNN具体操作:

K最邻近密度估计技术是一种分类方法,不是聚类方法。

不是最优方法,实践中比较流行。

通俗但不一定易懂的规则是:

1.计算待分类数据和不同类中每一个数据的距离(欧氏或马氏)。

2.选出最小的前K数据个距离,这里用到选择排序法。

3.对比这前K个距离,找出K个数据中包含最多的是那个类的数据,即为待分类数据所在的类。

不通俗但严谨的规则是:

给定一个位置特征向量x和一种距离测量方法,于是有:

1.在N个训练向量外,不考虑类的标签来确定k邻近。在两类的情况下,k选为奇数,一般不是类M的倍数。

2.在K个样本之外,确定属于wi,i=1,2,...M类的向量的个数ki,显然sum(ki)=k。

3.x属于样本最大值ki的那一类wi。

如下图,看那个绿色的值,是算三角类呢还是算矩类形呢,这要看是用几NN了,要是3NN就属于三角,要是5NN就属于矩形。

至于K到底取几,不同情况都要区别对待的。


下面是相关matlab代码:

复制代码
clear all;close all;clc;%%第一个类数据和标号mu1=[0 0];  %均值S1=[0.3 0;0 0.35];  %协方差data1=mvnrnd(mu1,S1,100);   %产生高斯分布数据plot(data1(:,1),data1(:,2),'+');label1=ones(100,1);hold on;%%第二个类数据和标号mu2=[1.25 1.25];S2=[0.3 0;0 0.35];data2=mvnrnd(mu2,S2,100);plot(data2(:,1),data2(:,2),'ro');label2=label1+1;data=[data1;data2];label=[label1;label2];K=11;   %两个类,K取奇数才能够区分测试数据属于那个类%测试数据,KNN算法看这个数属于哪个类for ii=-3:0.1:3    for jj=-3:0.1:3        test_data=[ii jj];  %测试数据        label=[label1;label2];        %%下面开始KNN算法,显然这里是11NN。        %求测试数据和类中每个数据的距离,欧式距离(或马氏距离)         distance=zeros(200,1);        for i=1:200            distance(i)=sqrt((test_data(1)-data(i,1)).^2+(test_data(2)-data(i,2)).^2);        end        %选择排序法,只找出最小的前K个数据,对数据和标号都进行排序        for i=1:K            ma=distance(i);            for j=i+1:200                if distance(j)<ma                    ma=distance(j);                    label_ma=label(j);                    tmp=j;                end            end            distance(tmp)=distance(i);  %排数据            distance(i)=ma;            label(tmp)=label(i);        %排标号,主要使用标号            label(i)=label_ma;        end        cls1=0; %统计类1中距离测试数据最近的个数        for i=1:K           if label(i)==1               cls1=cls1+1;           end        end        cls2=K-cls1;    %类2中距离测试数据最近的个数                if cls1>cls2               plot(ii,jj);     %属于类1的数据画小黑点        end            endend
复制代码

代码中是两个高斯分布的类,变量取x=-3:3,y=-3:3中的数据,看看这些数据都是属于哪个类。

下面是运行效果图:


0 0
原创粉丝点击