python数据挖掘实践第一章 KNN算法,以及算法的实现

来源:互联网 发布:监控客户端软件 编辑:程序博客网 时间:2024/05/16 09:09

KNN思路:

1.利用测试集和训练数据的相似性(目标距离的远近),取前K个

2.计算前K个分类的数据多少

3.选择数量最多的分类作为目标分类


算法思路:

   不考虑时间复杂度

1、读取测试集,或者自动生成

2、遍历测试集

3、与训练集的每一行,进行计算距离(采用两点的距离公式),同时保存每一行计算的距离,和分类的值(生成一个新的矩阵或者集合)

4、选取前K个(这里要对距离和分类进行排序,排序算法)

5、对前K个(已经排好序的)分类进行分组,并按各分类出现的数量排序,数量最多的分类就是目标



下面是python 2.7的实现

from numpy import * 


import operator


from os import listdir
import pylab as pl 


from matplotlib import pyplot as plt


def file2Matrix(filename):
    """Parse a tab-separated data file into a feature matrix and label list.

    Each line is expected to hold three numeric feature columns followed by
    an integer class label, all separated by tabs.

    Args:
        filename: path of the tab-separated data file.

    Returns:
        (returnMat, classLabelVector): an (n, 3) numpy array of features and
        a list of n integer class labels.
    """
    # Read the file exactly once and close it; the original opened the file
    # twice (once just to count lines) and never closed either handle.
    with open(filename) as fr:
        lines = fr.readlines()
    returnMat = zeros((len(lines), 3))      # one row of 3 features per line
    classLabelVector = []                   # integer label per line
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        returnMat[index, :] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


##
def createDataSet():
    """Build a tiny hard-coded sample set for smoke-testing the classifier.

    Returns:
        (label, group): a list of four class labels and the matching
        (4, 2) numpy array of 2-D sample points.
    """
    points = array([[1.0, 1.1],
                    [1.0, 1.0],
                    [0, 0],
                    [0, 0.1]])
    tags = ['A', 'A', 'B', 'B']
    # NOTE: labels come first in the return tuple, matching the callers below.
    return tags, points




def classify0(inX, dataSet, labels, k):
    """Classify inX by majority vote among its k nearest training samples.

    Args:
        inX: the feature vector to classify (length must match dataSet's columns).
        dataSet: (m, n) numpy array of training feature vectors.
        labels: sequence of m class labels, aligned with dataSet's rows.
        k: number of nearest neighbours to vote.

    Returns:
        The label that occurs most often among the k nearest neighbours.
    """
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every training row.
    # (Removed a stray debug print of dataSet.shape[1] left in the original.)
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDistances = (diffMat ** 2).sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for a in range(k):
        voteIlabel = labels[sortedDistIndicies[a]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() instead of the Python-2-only iteritems() keeps this portable.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]




def classify01(inX, dataSet, labels, k):
    """Classify inX by majority vote among its k nearest training samples.

    NOTE(review): this is an exact duplicate of classify0 (minus its debug
    print); kept for callers, but one of the two should eventually be removed.

    Args:
        inX: the feature vector to classify (length must match dataSet's columns).
        dataSet: (m, n) numpy array of training feature vectors.
        labels: sequence of m class labels, aligned with dataSet's rows.
        k: number of nearest neighbours to vote.

    Returns:
        The label that occurs most often among the k nearest neighbours.
    """
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every training row.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDistances = (diffMat ** 2).sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # items() instead of the Python-2-only iteritems() keeps this portable.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]




def autoNorm(datamingDataMat):
    """Min-max normalise each column of datamingDataMat into [0, 1].

    Args:
        datamingDataMat: (m, n) numpy array of raw feature values.

    Returns:
        (normDataSet, ranges, minValue): the normalised (m, n) array, the
        per-column value range (max - min), and the per-column minimum.
    """
    minValue = datamingDataMat.min(0)   # column-wise minima
    maxValue = datamingDataMat.max(0)   # column-wise maxima
    # Renamed from 'range', which shadowed the builtin; also dropped the
    # debug print and the dead zeros() allocation from the original.
    ranges = maxValue - minValue
    m = datamingDataMat.shape[0]
    normDataSet = datamingDataMat - tile(minValue, (m, 1))
    # Bug fix: the original divided the *raw* matrix by the range, throwing
    # away the min subtraction; divide the shifted values instead.
    normDataSet = normDataSet / tile(ranges, (m, 1))
    return normDataSet, ranges, minValue
    
if __name__ == "__main__":
    # Smoke-test: build the tiny hard-coded sample set and show its labels.
    label,group = createDataSet()
    print label
    ##pl.plot(group,'o')
    ##pl.show()
    #classify0([0,0],group,label,3)


    # Read the dating data file (earlier experiments, kept commented out).


    #datamingDataMat,datingLabels = file2Matrix('E:\\数据挖掘资料\\《机器学习实战》源代码\\machinelearninginaction\\Ch02\\datingTestSet2.txt')
    #
##    print len(datingLabels)
##    print datamingDataMat


    # Scatter-plot experiment, also kept commented out.


##    fig = plt.figure()
##    ax = fig.add_subplot(111)
##    ax.scatter(datamingDataMat[:,0],datamingDataMat[:,1],
##              15.0*array(datingLabels),15.0*array(datingLabels))
##    plt.show()
##    normMat,ranges,minValue =autoNorm(datamingDataMat)
##
##    #print 'normorize:',normMat,'ranges:',ranges
##
##    m = normMat.shape[0]
##
##    print m
    # Hold-out evaluation: classify the first 10% of the (normalised) rows
    # against the remaining 90% and report the error rate.
    # NOTE(review): the data path below is an absolute local Windows path —
    # this block only runs on the original author's machine.
    hoRatio = 0.10      #hold out 10%
    datingDataMat,datingLabels = file2Matrix('E:\\数据挖掘资料\\《机器学习实战》源代码\\machinelearninginaction\\Ch02\\datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)    # number of held-out test vectors
    errorCount = 0.0
    for i in range(numTestVecs):
        # Rows [numTestVecs:m] are the training set; row i is the test vector.
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print "the total error rate is: %f" % (errorCount/float(numTestVecs))
    print errorCount

---------------------------------------------------------------------------------------
    


以下是java 代码的实现

public class tztKNN {



/**
* set unique  coparator function,distance biggest,
*/
private static Comparator<KNNNode> comparator = new Comparator<KNNNode>() {


@Override
public int compare(KNNNode o1, KNNNode o2) {
// TODO Auto-generated method stub
if(o1.getDistance() >=o2.getDistance()){
return -1;
} else{
return 1;
}

}


};


public static void main(String[] args) {
// TODO Auto-generated method stub
   //第一加载文件
String fileName="E:\\数据挖掘资料\\《机器学习实战》源代码\\datingTestSet2.txt";

List<List<Float>>list = ReadFile(fileName);
/* System.out.println(list.get(0).toString());
System.out.println(list.get(1).toString());*/

List<Float>testlist =new ArrayList();
testlist.add(40920f);
testlist.add(8.326976f);
testlist.add(1.673904f);

//第二步
//可以直接计算距离
PriorityQueue<KNNNode> result = new PriorityQueue<KNNNode>();

result=Distances(testlist,list,3);
System.out.println("---------result-------");
//在这里计算出现最多的就是目标分类

String classif_re=getMostClass(result);


System.out.println(classif_re);
}


private static  String getMostClass(PriorityQueue<KNNNode> result) {
Map<String,Integer> classCount = new HashMap<>();
for(int i = 0; i<5; i++){
KNNNode node = result.poll();
String c = node.getC();
System.out.println("c:"+c);
if(classCount.containsKey(c)){
classCount.put(c, classCount.get(c)+1);
}
else{
classCount.put(c, 1);
}
}



        List<Entry<String,Integer>> list =new ArrayList<Entry<String,Integer>>(classCount.entrySet());
        //最后通过Collections.sort(List l, Comparator c)方法来进行排序,代码如下:


         Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {
                      public int compare(Map.Entry<String, Integer> o1,
                       Map.Entry<String, Integer> o2) {
                     return (o2.getValue() - o1.getValue());
           }
          });
         return list.get(0).getKey();
// TODO Auto-generated method stub
/*String re="";
Map<String, Integer> classCount = new HashMap<String, Integer>() ;
    int pqsize = 5;
    for (int i =0; i <pqsize; i++){
    KNNNode  node = result.remove();
    String c = node.getC();
    if (classCount.containsKey(c)){
    classCount.put(c, classCount.get(c)+1);
    }else{
    classCount.put(c, 1);
    }
    }
   
    int maxIndex =-1;
    int maxCount=0;
    Object[]classes = classCount.keySet().toArray();
    for (int i = 0; i<classes.length; i++){
    if (classCount.get(classes[i])>maxCount){
    maxIndex = i;
    maxCount = classCount.get(classes[i]);
    }
    }*/
    //return (String) classes[maxIndex];
}


private static PriorityQueue<KNNNode> Distances(List<Float> testlist, List<List<Float>> list,int k) {
// TODO Auto-generated method stub

double distance = 0.00;    
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k,comparator);
// 要存储每个元组的顺序,距离,和他们的分类

for(int i = 0; i<list.size(); i++){
List<Float> t = list.get(i);
String c = t.get(t.size() - 1).toString();

 for (int j= 0; j < testlist.size(); j++) {
             float test0 = testlist.get(j);
             float base0 = t.get(j);
             distance = (test0- base0)*(test0- base0);
    }  
 System.out.println("the:"+i+" ci:"+"c:"+c+" distace:"+distance);
 KNNNode node = new KNNNode(i,distance,c); 
 pq.add(node);
}


 
   return pq;



}


private static List<List<Float>> ReadFile(String fileName) {
// TODO Auto-generated method stub
List<List<Float>> datasets = new ArrayList<>();
try {
File file = new File(fileName);
if(file.exists()){
InputStreamReader reader=new InputStreamReader(new FileInputStream(file),"GBK" );
BufferedReader bufferedReader = new BufferedReader(reader);

String lineTxt = "";

while((lineTxt =bufferedReader.readLine())!=null){
//System.out.println(lineTxt);
//分离分行数据
ArrayList<Float> aa = new ArrayList<>();
String[] shuxing =lineTxt.split("");

for (String sx :shuxing){
//System.out.println(sx);
aa.add(Float.parseFloat(sx));
}
datasets.add(aa);
}

  }else{
           System.out.println("找不到指定的文件");
       }


} catch (Exception e) {
// TODO: handle exception
System.out.println("读取文件内容出错");
           e.printStackTrace();
}

return datasets;
}


}
0 0