KNN分类器

来源:互联网 发布:网络推广工作基本做法 编辑:程序博客网 时间:2024/06/07 10:56

                                                                   转载请注明:http://blog.csdn.net/suky520

理论部分:



评述:

 1)KNN方法主要依据周围有限的临近样本,而不是依靠判别类域的方法来判断所属类别。因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

 2)当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数,这里可以采用权值的方法(和该样本距离小的邻居权值大)来改进。

 3)该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。

 4)目前常用的解决方法是事先对已知样本点进行处理,事先去除对分类作用不大的样本。

 5)该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。


KNN二分类器实例:

数据:数据集有3000个样本,每个样本是三维的,其中两个位是特征,一位是标签。选择前面2500个样本作为训练数据,后面500个样本作为测试数据。这里要求使用KNN来鉴别这些测试数据。

打开数据集

代码实现:

load data1;double temp;distance=zeros(2,2500);k=5;right=0;%对500个测试数据,测试每一个样本点类别,并统计正确率for m=2501:1:3000  %测试循环    %第一步:求待测点到已知类别的数据集的距离    for i=1:1:2500        distance(1,i)=sqrt( (data(1,i)-data(1,m))^2 + (data(2,i)-data(2,m))^2 );        distance(2,i)=data(3,i);    end    %第二步:对距离进行排序(从小到大)    [distance(1,:),ind]=sort(distance(1,:),2);    distance(2,:) = distance(2,ind);    sum1=0;    sum2=0;    %第三步:选择前面k个距离对应的点,并且统计各个类别的频数(频率)    for i=1:1:k       if(distance(2,i)==1) %类别1(+1)         sum1=sum1+1;        else         sum2=sum2+1;       %类别2(-1)       end    end    %第四步:将出现频率(频率)最大的类别作为测试点的类别。    %这里在统计正确率    if(((sum1>sum2)&&(data(3,m)==1))||((sum1<sum2)&&(data(3,m)==-1)))       right=right+1;    endendarr=right/500

测试结果的正确率:100%(可能是样本问题吧??)

KNN进阶(多类情况):

对于图像数据,如0-9的手写字母识别问题,每个字母的个数近200,图像大小为32x32的0-1矩阵,训练样本近2000,测试样本近1000。

下面是我使用matlab代码实现的。

第一:图像数据转换列向量形式(img2vector.m)

function imgVector=img2vector(filename)%将32*32的数字图像转换为1024*1的列向量rows = 32;cols = 32;imgVector = zeros(rows * cols,1); %列向量fid = fopen(filename,'r');for row=1:rows    tline = fgetl(fid);  %读取txt文件中一行字符串    for col=1:cols        %将字符转换为数字保存到向量中        imgVector((row-1)*32 + col) = tline(col) - 48;     end   end   fclose(fid);

第二:加载数据集(包括训练数据和测试数据)loadDataSet.m

%数据的格式:一列代表一个样本(包含了特征以及标签(最后一位))%得到训练样本集file = dir('.\digits\trainingDigits\*.txt');train_x=zeros(32*32,length(file)); %train_y=zeros(1,length(file));for n=1:length(file)    filename = file(n).name;    label = regexp(filename,'_','split'); %正则法则,按'_'分割字符串    train_y(n)=cell2mat(label(1)) - 48;   %cell2mat把cell类型转换为数字型    filename = strcat('.\digits\trainingDigits\',filename);    train_x(:,n)=img2vector(filename);end%得到测试样本集file = dir('.\digits\testDigits\*.txt');test_x=zeros(32*32,length(file)); %test_y=zeros(1,length(file));for n=1:length(file)    filename = file(n).name;    label = regexp(filename,'_','split'); %正则法则,按'_'分割字符串    test_y(n)=cell2mat(label(1)) - 48;    %cell2mat把cell类型转换为数字型    filename = strcat('.\digits\testDigits\',filename);    test_x(:,n)=img2vector(filename);end%训练数据trainData=[train_x ;train_y];%测试数据testData=[test_x ; test_y ];save DATA trainData testData     
第三:测试手写字母(testHandWritingClass.m):

clc,cleartic%加载数据load DATA.mattrain_x= trainData(1:end-1,:);  %训练样本train_y= trainData(end,:);      %训练样本的标签(类别)test_x = testData(1:end-1,:);   %测试样本test_y = testData(end,:);       %测试样本的标签(类别)[n,m]=size(test_x); %m是测试样本的个数numTestSamples = m;right = 0;k=3;%测试数据for i = 1:1:numTestSamples    %knn分类器    predict = knn_Classify(test_x(:,i), train_x, train_y, k);    if predict == test_y(i)        right = right + 1;    endend%精度accuray = right/numTestSamplestoc
第四:knn分类器(knn_Classify.m)

function  predict = knn_Classify(newInput, dataSet, labels, k)  %newInput:测试数据(是一个列向量)  %dataSet: 训练样本集(其中,一列代表一个样本)  %labels:  训练样本的标签(类别)    %第一步:求待测点到已知类别的数据集的距离  [n,m]=size(dataSet);  distance=zeros(2,m);  for i=1:1:m     distance(1,i)=sqrt(dot(dataSet(:,i)-newInput,dataSet(:,i)-newInput));     distance(2,i)=labels(i);  end    %第二步:对距离进行排序(从小到大)  [distance(1,:),ind]=sort(distance(1,:),2);  distance(2,:) = distance(2,ind);      %第三步:选择前面k个距离对应的点,并且统计各个类别的频数(频率)  class=unique(labels);      %0-9个类别,unique(去除重复的元素)  count=zeros(2,length(class));  %用于各个类别的个数      for j=1:length(class)     sum=0;     %统计在前面k个距离对应点中各个类别出现频数     for i=1:1:k       if(distance(2,i)==class(j)) %类别           sum= sum + 1;                       end            end     count(:,j) = [sum;class(j)];      end     %第四步:将出现频率(频率)最大的类别作为测试点的类别。 [val,ind]=max(count(1,:)); %ind是最大值val所在位置 predict = count(2,ind);          

测试结果为:

accuray =    0.9873Elapsed time is 61.209153 seconds.


参考1:http://blog.csdn.net/zouxy09/article/details/16955347

0 0
原创粉丝点击