数据挖掘10大算法(6)-K最近邻(KNN)算法的实现(java和python版)

来源:互联网 发布:淘宝大韩泡泡糖是高仿 编辑:程序博客网 时间:2024/05/16 04:33

数据挖掘-K最近邻(KNN)算法的实现(java和python版)

KNN算法基础思想前面文章可以参考,这里主要讲解java和python的两种简单实现,也主要是理解简单的思想。

http://blog.csdn.net/u011067360/article/details/23941577

python版本:

这里实现一个手写识别算法,这里只简单识别0~9熟悉,在上篇文章中也展示了手写识别的应用,可以参考:机器学习与数据挖掘-logistic回归及手写识别实例的实现

输入:每个手写数字已经事先处理成32*32的二进制文本,存储为txt文件。0~9每个数字都有10个训练样本,5个测试样本。训练样本集如下图:左边是文件目录,右边是其中一个文件打开显示的结果,看着像1,这里有0~9,每个数字都有是个样本来作为训练集。




第一步:将每个txt文本转化为一个向量,即32*32的数组转化为1*1024的数组,这个1*1024的数组用机器学习的术语来说就是特征向量。

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. <span style="font-size:14px;">def img2vector(filename):  
  2.     returnVect = zeros((1,1024))  
  3.     fr = open(filename)  
  4.     for i in range(32):  
  5.         lineStr = fr.readline()  
  6.         for j in range(32):  
  7.             returnVect[0,32*i+j] = int(lineStr[j])  
  8.     return returnVect</span>  

第二步:训练样本中有10*10个图片,可以合并成一个100*1024的矩阵,每一行对应一个图片,也就是一个txt文档。

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def handwritingClassTest():  
  2.   
  3.     hwLabels = []  
  4.     trainingFileList = listdir('trainingDigits')    
  5.     print trainingFileList          
  6.     m = len(trainingFileList)  
  7.     trainingMat = zeros((m,1024))  
  8.     for i in range(m):  
  9.         fileNameStr = trainingFileList[i]            
  10.         fileStr = fileNameStr.split('.')[0]  
  11.         classNumStr = int(fileStr.split('_')[0])   
  12.         hwLabels.append(classNumStr)  
  13.         #print hwLabels  
  14.         #print fileNameStr     
  15.         trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)  
  16.         #print trainingMat[i,:]   
  17.         #print len(trainingMat[i,:])  
  18.        
  19.     testFileList = listdir('testDigits')         
  20.     errorCount = 0.0  
  21.     mTest = len(testFileList)  
  22.     for i in range(mTest):  
  23.         fileNameStr = testFileList[i]  
  24.         fileStr = fileNameStr.split('.')[0]       
  25.         classNumStr = int(fileStr.split('_')[0])  
  26.         vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)  
  27.         classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  
  28.         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)  
  29.         if (classifierResult != classNumStr): errorCount += 1.0  
  30.     print "\nthe total number of errors is: %d" % errorCount  
  31.     print "\nthe total error rate is: %f" % (errorCount/float(mTest))  

第三步:测试样本中有10*5个图片,同样的,对于测试图片,将其转化为1*1024的向量,然后计算它与训练样本中各个图片的“距离”(这里两个向量的距离采用欧式距离),然后对距离排序,选出较小的前k个,因为这k个样本来自训练集,是已知其代表的数字的,所以被测试图片所代表的数字就可以确定为这k个中出现次数最多的那个数字。

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def classify0(inX, dataSet, labels, k):  
  2.     dataSetSize = dataSet.shape[0]  
  3.     #tile(A,(m,n))     
  4.     print dataSet  
  5.     print "----------------"  
  6.     print tile(inX, (dataSetSize,1))  
  7.     print "----------------"  
  8.     diffMat = tile(inX, (dataSetSize,1)) - dataSet        
  9.     print diffMat  
  10.     sqDiffMat = diffMat**2  
  11.     sqDistances = sqDiffMat.sum(axis=1)                    
  12.     distances = sqDistances**0.5  
  13.     sortedDistIndicies = distances.argsort()              
  14.     classCount={}                                        
  15.     for i in range(k):  
  16.         voteIlabel = labels[sortedDistIndicies[i]]  
  17.         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1  
  18.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  19.     return sortedClassCount[0][0]  


全部实现代码:
[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. #-*-coding:utf-8-*-  
  2. from numpy import *  
  3. import operator  
  4. from os import listdir  
  5.   
  6. def classify0(inX, dataSet, labels, k):  
  7.     dataSetSize = dataSet.shape[0]  
  8.     #tile(A,(m,n))     
  9.     print dataSet  
  10.     print "----------------"  
  11.     print tile(inX, (dataSetSize,1))  
  12.     print "----------------"  
  13.     diffMat = tile(inX, (dataSetSize,1)) - dataSet        
  14.     print diffMat  
  15.     sqDiffMat = diffMat**2  
  16.     sqDistances = sqDiffMat.sum(axis=1)                    
  17.     distances = sqDistances**0.5  
  18.     sortedDistIndicies = distances.argsort()              
  19.     classCount={}                                        
  20.     for i in range(k):  
  21.         voteIlabel = labels[sortedDistIndicies[i]]  
  22.         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1  
  23.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  24.     return sortedClassCount[0][0]  
  25.   
  26. def img2vector(filename):  
  27.     returnVect = zeros((1,1024))  
  28.     fr = open(filename)  
  29.     for i in range(32):  
  30.         lineStr = fr.readline()  
  31.         for j in range(32):  
  32.             returnVect[0,32*i+j] = int(lineStr[j])  
  33.     return returnVect  
  34.   
  35. def handwritingClassTest():  
  36.   
  37.     hwLabels = []  
  38.     trainingFileList = listdir('trainingDigits')    
  39.     print trainingFileList          
  40.     m = len(trainingFileList)  
  41.     trainingMat = zeros((m,1024))  
  42.     for i in range(m):  
  43.         fileNameStr = trainingFileList[i]            
  44.         fileStr = fileNameStr.split('.')[0]  
  45.         classNumStr = int(fileStr.split('_')[0])   
  46.         hwLabels.append(classNumStr)  
  47.         #print hwLabels  
  48.         #print fileNameStr     
  49.         trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)  
  50.         #print trainingMat[i,:]   
  51.         #print len(trainingMat[i,:])  
  52.        
  53.     testFileList = listdir('testDigits')         
  54.     errorCount = 0.0  
  55.     mTest = len(testFileList)  
  56.     for i in range(mTest):  
  57.         fileNameStr = testFileList[i]  
  58.         fileStr = fileNameStr.split('.')[0]       
  59.         classNumStr = int(fileStr.split('_')[0])  
  60.         vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)  
  61.         classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  
  62.         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)  
  63.         if (classifierResult != classNumStr): errorCount += 1.0  
  64.     print "\nthe total number of errors is: %d" % errorCount  
  65.     print "\nthe total error rate is: %f" % (errorCount/float(mTest))  
  66.       
  67. handwritingClassTest()      

运行结果:源码文章尾可下载



java版本

先看看训练集和测试集:

训练集:


测试集:



训练集最后一列代表分类(0或者1)


代码实现:

 KNN算法主体类:

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package Marchinglearning.knn2;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Comparator;  
  5. import java.util.HashMap;  
  6. import java.util.List;  
  7. import java.util.Map;  
  8. import java.util.PriorityQueue;  
  9.   
  10. /** 
  11.  * KNN算法主体类 
  12.  */  
  13. public class KNN {  
  14.     /** 
  15.      * 设置优先级队列的比较函数,距离越大,优先级越高 
  16.      */  
  17.     private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
  18.         public int compare(KNNNode o1, KNNNode o2) {  
  19.             if (o1.getDistance() >= o2.getDistance()) {  
  20.                 return 1;  
  21.             } else {  
  22.                 return 0;  
  23.             }  
  24.         }  
  25.     };  
  26.     /** 
  27.      * 获取K个不同的随机数 
  28.      * @param k 随机数的个数 
  29.      * @param max 随机数最大的范围 
  30.      * @return 生成的随机数数组 
  31.      */  
  32.     public List<Integer> getRandKNum(int k, int max) {  
  33.         List<Integer> rand = new ArrayList<Integer>(k);  
  34.         for (int i = 0; i < k; i++) {  
  35.             int temp = (int) (Math.random() * max);  
  36.             if (!rand.contains(temp)) {  
  37.                 rand.add(temp);  
  38.             } else {  
  39.                 i--;  
  40.             }  
  41.         }  
  42.         return rand;  
  43.     }  
  44.     /** 
  45.      * 计算测试元组与训练元组之前的距离 
  46.      * @param d1 测试元组 
  47.      * @param d2 训练元组 
  48.      * @return 距离值 
  49.      */  
  50.     public double calDistance(List<Double> d1, List<Double> d2) {  
  51.         System.out.println("d1:"+d1+",d2"+d2);  
  52.         double distance = 0.00;  
  53.         for (int i = 0; i < d1.size(); i++) {  
  54.             distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
  55.         }  
  56.         return distance;  
  57.     }  
  58.     /** 
  59.      * 执行KNN算法,获取测试元组的类别 
  60.      * @param datas 训练数据集 
  61.      * @param testData 测试元组 
  62.      * @param k 设定的K值 
  63.      * @return 测试元组的类别 
  64.      */  
  65.     public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
  66.         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
  67.         List<Integer> randNum = getRandKNum(k, datas.size());  
  68.         System.out.println("randNum:"+randNum.toString());  
  69.         for (int i = 0; i < k; i++) {  
  70.             int index = randNum.get(i);  
  71.             List<Double> currData = datas.get(index);  
  72.             String c = currData.get(currData.size() - 1).toString();  
  73.             System.out.println("currData:"+currData+",c:"+c+",testData"+testData);  
  74.             //计算测试元组与训练元组之前的距离  
  75.             KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
  76.             pq.add(node);  
  77.         }  
  78.         for (int i = 0; i < datas.size(); i++) {  
  79.             List<Double> t = datas.get(i);  
  80.             System.out.println("testData:"+testData);  
  81.             System.out.println("t:"+t);  
  82.             double distance = calDistance(testData, t);  
  83.             System.out.println("distance:"+distance);  
  84.             KNNNode top = pq.peek();  
  85.             if (top.getDistance() > distance) {  
  86.                 pq.remove();  
  87.                 pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
  88.             }  
  89.         }  
  90.           
  91.         return getMostClass(pq);  
  92.     }  
  93.     /** 
  94.      * 获取所得到的k个最近邻元组的多数类 
  95.      * @param pq 存储k个最近近邻元组的优先级队列 
  96.      * @return 多数类的名称 
  97.      */  
  98.     private String getMostClass(PriorityQueue<KNNNode> pq) {  
  99.         Map<String, Integer> classCount = new HashMap<String, Integer>();  
  100.         for (int i = 0; i < pq.size(); i++) {  
  101.             KNNNode node = pq.remove();  
  102.             String c = node.getC();  
  103.             if (classCount.containsKey(c)) {  
  104.                 classCount.put(c, classCount.get(c) + 1);  
  105.             } else {  
  106.                 classCount.put(c, 1);  
  107.             }  
  108.         }  
  109.         int maxIndex = -1;  
  110.         int maxCount = 0;  
  111.         Object[] classes = classCount.keySet().toArray();  
  112.         for (int i = 0; i < classes.length; i++) {  
  113.             if (classCount.get(classes[i]) > maxCount) {  
  114.                 maxIndex = i;  
  115.                 maxCount = classCount.get(classes[i]);  
  116.             }  
  117.         }  
  118.         return classes[maxIndex].toString();  
  119.     }  
  120. }  

 KNN结点类,用来存储最近邻的k个元组相关的信息

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package Marchinglearning.knn2;  
  2. /** 
  3.  * KNN结点类,用来存储最近邻的k个元组相关的信息 
  4.  */  
  5. public class KNNNode {  
  6.     private int index; // 元组标号  
  7.     private double distance; // 与测试元组的距离  
  8.     private String c; // 所属类别  
  9.     public KNNNode(int index, double distance, String c) {  
  10.         super();  
  11.         this.index = index;  
  12.         this.distance = distance;  
  13.         this.c = c;  
  14.     }  
  15.       
  16.       
  17.     public int getIndex() {  
  18.         return index;  
  19.     }  
  20.     public void setIndex(int index) {  
  21.         this.index = index;  
  22.     }  
  23.     public double getDistance() {  
  24.         return distance;  
  25.     }  
  26.     public void setDistance(double distance) {  
  27.         this.distance = distance;  
  28.     }  
  29.     public String getC() {  
  30.         return c;  
  31.     }  
  32.     public void setC(String c) {  
  33.         this.c = c;  
  34.     }  
  35. }  

KNN算法测试类

[java] view plaincopy在CODE上查看代码片派生到我的代码片
  1. package Marchinglearning.knn2;  
  2. import java.io.BufferedReader;  
  3. import java.io.File;  
  4. import java.io.FileReader;  
  5. import java.util.ArrayList;  
  6. import java.util.List;  
  7. /** 
  8.  * KNN算法测试类 
  9.  */  
  10. public class TestKNN {  
  11.       
  12.     /** 
  13.      * 从数据文件中读取数据 
  14.      * @param datas 存储数据的集合对象 
  15.      * @param path 数据文件的路径 
  16.      */  
  17.     public void read(List<List<Double>> datas, String path){  
  18.         try {  
  19.             BufferedReader br = new BufferedReader(new FileReader(new File(path)));  
  20.             String data = br.readLine();  
  21.             List<Double> l = null;  
  22.             while (data != null) {  
  23.                 String t[] = data.split(" ");  
  24.                 l = new ArrayList<Double>();  
  25.                 for (int i = 0; i < t.length; i++) {  
  26.                     l.add(Double.parseDouble(t[i]));  
  27.                 }  
  28.                 datas.add(l);  
  29.                 data = br.readLine();  
  30.             }  
  31.         } catch (Exception e) {  
  32.             e.printStackTrace();  
  33.         }  
  34.     }  
  35.       
  36.     /** 
  37.      * 程序执行入口 
  38.      * @param args 
  39.      */  
  40.     public static void main(String[] args) {  
  41.         TestKNN t = new TestKNN();  
  42.         String datafile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator + "datafile.data";  
  43.         String testfile = new File("").getAbsolutePath() + File.separator +"knndata2"+File.separator +"testfile.data";  
  44.         System.out.println("datafile:"+datafile);  
  45.         System.out.println("testfile:"+testfile);  
  46.         try {  
  47.             List<List<Double>> datas = new ArrayList<List<Double>>();  
  48.             List<List<Double>> testDatas = new ArrayList<List<Double>>();  
  49.             t.read(datas, datafile);  
  50.             t.read(testDatas, testfile);  
  51.             KNN knn = new KNN();  
  52.             for (int i = 0; i < testDatas.size(); i++) {  
  53.                 List<Double> test = testDatas.get(i);  
  54.                 System.out.print("测试元组: ");  
  55.                 for (int j = 0; j < test.size(); j++) {  
  56.                     System.out.print(test.get(j) + " ");  
  57.                 }  
  58.                 System.out.print("类别为: ");  
  59.                 System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));  
  60.             }  
  61.         } catch (Exception e) {  
  62.             e.printStackTrace();  
  63.         }  
  64.     }  
  65. }  

运行结果为:



资源下载:

python版本下载

java版本下载



0 0
原创粉丝点击