机器学习:KNN算法(MATLAB实现)

来源:互联网 发布:空巢老人调查数据 编辑:程序博客网 时间:2024/05/18 00:38

   K-近邻算法的思想如下:首先,计算新样本与训练样本之间的距离,找到距离最近的K 个邻居;然后,根据这些邻居所属的类别来判定新样本的类别,如果它们都属于同一个类别,那么新样本也属于这个类;否则,对每个后选类别进行评分,按照某种规则确定新样本的类别。(统计出现的频率)

该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分当K值较小时可能产生过拟合,因为训练误差很小,但是测试误差可能很大;相反,当K值较大时可能产生欠拟合。

算法伪代码

对未知类别属性的数据集中的每个点依次执行以下操作:

(1)    计算已知类别的数据集中的点与当前点之间的距离;

(2)    按照距离递增次序排序;

(3)    选取与当前点距离最小的K个点;

(4)    确定前K个点所在类别的出现频率;

(5)    返回前K个点出现频率最高的类别作为当前点的预测分类。

 

[plain] view plaincopy在CODE上查看代码片派生到我的代码片
  1. %  
  2. %手写数字识别系统的测试代码  
  3. %  
  4. function handWritingTest()  
  5.     tic; %开始计时  
  6.     K = 3;  % 这里可以调整k值  
  7.     trainLabels = [];  
  8.     direct = mfilename('fullpath');%  
  9.     traindirect = strrep(direct,'handWritingTest','trainingDigits'); %trainingDigits  
  10.    %获得路径  
  11.     traindirfile = dir(fullfile(traindirect,'*.txt'));%提取后缀名.txt  
  12.     traindircell = struct2cell(traindirfile)';  
  13.     trainfilenames = traindircell(:,1);  
  14.     trainfileNums = length(trainfilenames);  
  15.     trainMat = zeros(trainfileNums,1024);  
  16.     for i = 1:trainfileNums  
  17.         fileNameStr = trainfilenames(i);  
  18.         str = deblank(fileNameStr);  
  19.         s = regexp(str,'\.','split'); %  
  20.         fileStr = s{1}(1);  
  21.         classNumStr =  regexp(fileStr,'\_','split');  
  22.         trainLabels(i)=str2num(char(classNumStr{1}(1))); %得到类别 0 - 9   
  23.         filePath = strcat(traindirect,'\',fileNameStr); %文件路径  
  24.         trainMat(i,:) = img2vector(filePath);%处理文件 获得向量  
  25.     end  
  26.       
  27.     %测试样本  
  28.     direct = mfilename('fullpath');  
  29.     testdirect = strrep(direct,'handWritingTest','testDigits');%testDigits  
  30.     testdirfile = dir(fullfile(testdirect,'*.txt'));  
  31.     testdircell = struct2cell(testdirfile)';  
  32.     testfilenames = testdircell(:,1);  
  33.     testfileNums = length(testfilenames);  
  34.     errorcount = 0;  
  35.     for j = 1:testfileNums  
  36.         fileNameStr = testfilenames(j);  
  37.         str = deblank(fileNameStr);  
  38.         s = regexp(str,'\.','split');  
  39.         fileStr = s{1}(1);  
  40.         classNumStr =  regexp(fileStr,'\_','split');  
  41.         testLabel = str2num(char(classNumStr{1}(1))); %得到类别 0 - 9   
  42.         filePath = strcat(testdirect,'\',fileNameStr);  
  43.         testVector = img2vector(filePath);  
  44.         classifyRet = classify(testVector,trainMat,trainLabels,K);  
  45.         if(classifyRet ~= testLabel)  
  46.             errorcount = errorcount + 1;  
  47.             fprintf('test result:  %d,    real result:  %d ,    here error!!! \n',classifyRet,testLabel);  
  48.         else  
  49.             fprintf('test result:  %d,    real result:  %d \n',classifyRet,testLabel);  
  50.         end  
  51.     end  
  52.     lastTime = num2str(toc);  
  53.     fprintf('\n the sum numbers of errors :  %d ',errorcount);  
  54.     fprintf('\n the total error rate :  %f  ' ,(errorcount / testfileNums));  
  55.     fprintf('\n total time :    %f',lastTime);  
  56. end  
  57.   
  58. %  
  59. %KNN算法 classify(test,dataSet,labels,k)  
  60. %四个参数:test用于分类的输入向量;输入的训练样本集为dataSet;  
  61. %标签向量为labels; k 表示用于选择最近邻居的数目;  
  62. %  
  63.   
  64. function maxClass = classify(test,dataSet,labels,k)  
  65.     [dataRow,dataCol] = size(dataSet);%dataRow:样本个数;dataCol:特征  
  66.     %求距离 test 与样本数据之间的距离   这里为欧式距离  
  67.     diffMat = dataSet;  
  68.     for i = 1:dataRow  
  69.         diffMat(i,:) = diffMat(i,:) - test;   
  70.     end  
  71.     sqdiffMat = diffMat.^2;  
  72.     sqDistances = sum(sqdiffMat,2).^(0.5);  
  73.     [p,q] = sort(sqDistances);  %p代表要排序的数,q代表要排序的数原来对应的索引  
  74.     %通过k  来求最邻居的前k 个数据,然后找的在这些数据中类别最多的  
  75.     classCount=zeros(10,1);  
  76.     class = [];  
  77.     for j = 1:k  
  78.         tempLabel = labels(q(j));  
  79.         class(j) = tempLabel;%没用到  
  80.         classCount(tempLabel+1) = classCount(tempLabel+1)+1;  
  81.     end  
  82.     [r,s] = max(classCount);  
  83.     maxClass = s - 1;  %返回 相似个数最多的 那个类  
  84. end  
  85.   
  86. %  
  87. %将32*32的二进制图形矩阵转换为1*1024的向量  
  88. %  
  89. function retVector = img2vector(fileName)  
  90.    fileName = char(fileName);  
  91.    tempVector = [];  
  92.     % 读文件  
  93.    fileData = textread(fileName,'%s');  
  94.    fileData = char(fileData);%读取文件,并将文件转换矩阵的格式  
  95.    temp = fileData(:)';  
  96.    for i = 1 : length(temp)  
  97.        tempVector(i) = str2num(temp(i));  
  98.    end  
  99.    retVector = tempVector;  
  100. end 
0 0
原创粉丝点击