最邻近规则分类(K-Nearest Neighbor)KNN算法应用

来源:互联网 发布:西门子tia博途软件 编辑:程序博客网 时间:2024/05/17 04:27

最邻近规则分类(K-Nearest Neighbor)KNN算法应用

1.Iris数据集介绍

调用ython的机器学习库sklearn实现虹膜分类

Iris数据包含150条样本记录,分剐取自三种不同的鸢尾属植物setosa、versic010r和virginica的花朵样本,每一
类各50条记录,其中每条记录有4个属性:萼片长度(sepal length)、萼片宽度sepalwidth)、花瓣长度(petal length)和花瓣宽度(petal width)。
这是一个极其简单的域。
数据集(irisdata.txt)下载地址:https://pan.baidu.com/s/1bpHCyHD
类别:
Iris setosa, Iris versicolor, Iris virginica.

2. 利用Python的机器学习库sklearn: iris.py

    # -*- coding:utf-8 -*-    from sklearn import neighbors #导入包含邻近算法的模块    from sklearn import datasets #导入自带的数据集模块,导入后我们就可以利用自带的数据集    knn = neighbors.KNeighborsClassifier() #调用KNN的分类器    iris = datasets.load_iris()#load_iris()方法返回iris的数据库    # save data    # f = open("iris.data.csv", 'wb')    # f.write(str(iris))    # f.close()    print iris #iris数据集就是这个样子的    # {'target_names': array(['setosa', 'versicolor', 'virginica'],    #   dtype='|S10'), 'data': array([[ 5.1,  3.5,  1.4,  0.2],    #[ 4.9,  3. ,  1.4,  0.2],    #[ 4.7,  3.2,  1.3,  0.2],    #[ 4.6,  3.1,  1.5,  0.2],    #[ 5. ,  3.6,  1.4,  0.2],    #[ 5.4,  3.9,  1.7,  0.4],    #[ 4.6,  3.4,  1.4,  0.3],    #[ 5. ,  3.4,  1.5,  0.2],    #[ 4.4,  2.9,  1.4,  0.2],    #[ 4.9,  3.1,  1.5,  0.1],    #[ 5.4,  3.7,  1.5,  0.2],    #[ 4.8,  3.4,  1.6,  0.2],    #[ 4.8,  3. ,  1.4,  0.1],    #[ 4.3,  3. ,  1.1,  0.1],    #[ 5.8,  4. ,  1.2,  0.2],    #[ 5.7,  4.4,  1.5,  0.4],    #[ 5.4,  3.9,  1.3,  0.4],    #[ 5.1,  3.5,  1.4,  0.3],    #[ 5.7,  3.8,  1.7,  0.3],    #[ 5.1,  3.8,  1.5,  0.3],    #[ 5.4,  3.4,  1.7,  0.2],    #[ 5.1,  3.7,  1.5,  0.4],    #[ 4.6,  3.6,  1. ,  0.2],    #[ 5.1,  3.3,  1.7,  0.5],    #[ 4.8,  3.4,  1.9,  0.2],    #[ 5. ,  3. ,  1.6,  0.2],    #[ 5. ,  3.4,  1.6,  0.4],    #[ 5.2,  3.5,  1.5,  0.2],    #[ 5.2,  3.4,  1.4,  0.2],    #[ 4.7,  3.2,  1.6,  0.2],    #[ 4.8,  3.1,  1.6,  0.2],    #[ 5.4,  3.4,  1.5,  0.4],    #[ 5.2,  4.1,  1.5,  0.1],    #[ 5.5,  4.2,  1.4,  0.2],    #[ 4.9,  3.1,  1.5,  0.1],    #[ 5. ,  3.2,  1.2,  0.2],    #[ 5.5,  3.5,  1.3,  0.2],    #[ 4.9,  3.1,  1.5,  0.1],    #[ 4.4,  3. ,  1.3,  0.2],    #[ 5.1,  3.4,  1.5,  0.2],    #[ 5. ,  3.5,  1.3,  0.3],    #[ 4.5,  2.3,  1.3,  0.3],    #[ 4.4,  3.2,  1.3,  0.2],    #[ 5. ,  3.5,  1.6,  0.6],    #[ 5.1,  3.8,  1.9,  0.4],    #[ 4.8,  3. ,  1.4,  0.3],    #[ 5.1,  3.8,  1.6,  0.2],    #[ 4.6,  3.2,  1.4,  0.2],    #[ 5.3,  3.7,  1.5,  0.2],    #[ 5. ,  3.3,  1.4,  0.2],    #[ 7. ,  3.2,  4.7,  1.4],    #[ 6.4,  3.2,  4.5,  1.5],    #[ 6.9,  3.1,  4.9,  1.5],    #[ 5.5,  2.3,  4. ,  1.3],    #[ 6.5,  2.8,  4.6,  1.5],    #[ 5.7,  2.8,  4.5,  1.3],    #[ 6.3,  3.3,  4.7,  1.6],    #[ 4.9,  2.4,  3.3,  1. ],    #[ 6.6,  2.9,  4.6,  1.3],    #[ 5.2,  2.7,  3.9,  1.4],    #[ 5. ,  2. ,  3.5,  1. ],    #[ 5.9,  3. ,  4.2,  1.5],    #[ 6. ,  2.2,  4. ,  1. ],    #[ 6.1,  2.9,  4.7,  1.4],    #[ 5.6,  2.9,  3.6,  1.3],    #[ 6.7,  3.1,  4.4,  1.4],    #[ 5.6,  3. ,  4.5,  1.5],    #[ 5.8,  2.7,  4.1,  1. ],    #[ 6.2,  2.2,  4.5,  1.5],    #[ 5.6,  2.5,  3.9,  1.1],    #[ 5.9,  3.2,  4.8,  1.8],    #[ 6.1,  2.8,  4. ,  1.3],    #[ 6.3,  2.5,  4.9,  1.5],    #[ 6.1,  2.8,  4.7,  1.2],    #[ 6.4,  2.9,  4.3,  1.3],    #[ 6.6,  3. ,  4.4,  1.4],    #[ 6.8,  2.8,  4.8,  1.4],    #[ 6.7,  3. ,  5. ,  1.7],    #[ 6. ,  2.9,  4.5,  1.5],    #[ 5.7,  2.6,  3.5,  1. ],    #[ 5.5,  2.4,  3.8,  1.1],    #[ 5.5,  2.4,  3.7,  1. ],    #[ 5.8,  2.7,  3.9,  1.2],    #[ 6. ,  2.7,  5.1,  1.6],    #[ 5.4,  3. ,  4.5,  1.5],    #[ 6. ,  3.4,  4.5,  1.6],    #[ 6.7,  3.1,  4.7,  1.5],    #[ 6.3,  2.3,  4.4,  1.3],    #[ 5.6,  3. ,  4.1,  1.3],    #[ 5.5,  2.5,  4. ,  1.3],    #[ 5.5,  2.6,  4.4,  1.2],    #[ 6.1,  3. ,  4.6,  1.4],    #[ 5.8,  2.6,  4. ,  1.2],    #[ 5. ,  2.3,  3.3,  1. ],    #[ 5.6,  2.7,  4.2,  1.3],    #[ 5.7,  3. ,  4.2,  1.2],    #[ 5.7,  2.9,  4.2,  1.3],    #[ 6.2,  2.9,  4.3,  1.3],    #[ 5.1,  2.5,  3. ,  1.1],    #[ 5.7,  2.8,  4.1,  1.3],    #[ 6.3,  3.3,  6. ,  2.5],    #[ 5.8,  2.7,  5.1,  1.9],    #[ 7.1,  3. ,  5.9,  2.1],    #[ 6.3,  2.9,  5.6,  1.8],    #[ 6.5,  3. ,  5.8,  2.2],    #[ 7.6,  3. ,  6.6,  2.1],    #[ 4.9,  2.5,  4.5,  1.7],    #[ 7.3,  2.9,  6.3,  1.8],    #[ 6.7,  2.5,  5.8,  1.8],    #[ 7.2,  3.6,  6.1,  2.5],    #[ 6.5,  3.2,  5.1,  2. ],    #[ 6.4,  2.7,  5.3,  1.9],    #[ 6.8,  3. ,  5.5,  2.1],    #[ 5.7,  2.5,  5. ,  2. ],    #[ 5.8,  2.8,  5.1,  2.4],    #[ 6.4,  3.2,  5.3,  2.3],    #[ 6.5,  3. ,  5.5,  1.8],    #[ 7.7,  3.8,  6.7,  2.2],    #[ 7.7,  2.6,  6.9,  2.3],    #[ 6. ,  2.2,  5. ,  1.5],    #[ 6.9,  3.2,  5.7,  2.3],    #[ 5.6,  2.8,  4.9,  2. ],    #[ 7.7,  2.8,  6.7,  2. ],    #[ 6.3,  2.7,  4.9,  1.8],    #[ 6.7,  3.3,  5.7,  2.1],    #[ 7.2,  3.2,  6. ,  1.8],    #[ 6.2,  2.8,  4.8,  1.8],    #[ 6.1,  3. ,  4.9,  1.8],    #[ 6.4,  2.8,  5.6,  2.1],    #[ 7.2,  3. ,  5.8,  1.6],    #[ 7.4,  2.8,  6.1,  1.9],    #[ 7.9,  3.8,  6.4,  2. ],    #[ 6.4,  2.8,  5.6,  2.2],    #[ 6.3,  2.8,  5.1,  1.5],    #[ 6.1,  2.6,  5.6,  1.4],    #[ 7.7,  3. ,  6.1,  2.3],    #[ 6.3,  3.4,  5.6,  2.4],    #[ 6.4,  3.1,  5.5,  1.8],    #[ 6. ,  3. ,  4.8,  1.8],    #[ 6.9,  3.1,  5.4,  2.1],    #[ 6.7,  3.1,  5.6,  2.4],    #[ 6.9,  3.1,  5.1,  2.3],    #[ 5.8,  2.7,  5.1,  1.9],    #[ 6.8,  3.2,  5.9,  2.3],    #[ 6.7,  3.3,  5.7,  2.5],    #[ 6.7,  3. ,  5.2,  2.3],    #[ 6.3,  2.5,  5. ,  1.9],    #[ 6.5,  3. ,  5.2,  2. ],    #[ 6.2,  3.4,  5.4,  2.3],    #[ 5.9,  3. ,  5.1,  1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,    #0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,    #0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,    #1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,    #1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,    #2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,    #2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-----\nData Set Characteristics:\n:Number of Instances: 150 (50 in each of three classes)\n:Number of Attributes: 4 numeric, predictive attributes and the class\n:Attribute Information:\n- sepal length in cm\n- sepal width in cm\n- petal length in cm\n- petal width in cm\n- class:\n- Iris-Setosa\n- Iris-Versicolour\n- Iris-Virginica\n:Summary Statistics:\n\n============== ==== ==== ======= ===== ====================\nMin  Max   MeanSD   Class Correlation\n============== ==== ==== ======= ===== ====================\nsepal length:   4.3  7.9   5.84   0.830.7826\nsepal width:2.0  4.4   3.05   0.43   -0.4194\npetal length:   1.0  6.9   3.76   1.760.9490  (high!)\npetal width:0.1  2.5   1.20  0.76 0.9565  (high!)\n============== ==== ==== ======= ===== ====================\n\n:Missing Attribute Values: None\n:Class Distribution: 33.3% for each of 3 classes.\n:Creator: R.A. Fisher\n:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n:Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n----------\n   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments".  IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...\n', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']}    knn.fit(iris.data, iris.target)#fit(传入特征值矩阵,传入一维的向量标签),这句话就是让kNN建立一个模型    predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]]) #预测这个实例分类是哪一类    # print "hello"    #print ("predictedLabel is :" + predictedLabel)    print predictedLabel  #运行结果 0 ,也就是上面这个实例属于第一类

3. KNN 实现Implementation:

我们不用sklearn中自带的方法,自己手动写的,如下所示
数据集(irisdata.txt)下载地址:https://pan.baidu.com/s/1bpHCyHD
完整源码下载地址:KNNImplementation (myfix).py

    # -*- coding:utf-8 -*-    import csv    import random    import math    import operator    def loadDataset(filename, split, trainingSet = [], testSet = []):#loadDataset装载数据集        with open(filename, 'rb') as csvfile:            lines = csv.reader(csvfile)            dataset = list(lines)            # print dataset            for x in range(len(dataset)-1):                for y in range(4):                    dataset[x][y] = float(dataset[x][y])                if random.random() < split:                    trainingSet.append(dataset[x])                else:                    testSet.append(dataset[x])    def euclideanDistance(instance1, instance2, length):#如何计算距离,不仅仅能够计算二维的,他可以计算任意维度的        #instance1和instance2都是这样的形式[6.2, 3.4, 5.4, 2.3, 'Iris-virginica']        distance = 0        for x in range(length):            distance += pow((instance1[x]-instance2[x]), 2)        return math.sqrt(distance)    def getNeighbors(trainingSet, testInstance, k):#返回邻距        distances = []        length = len(testInstance)-1  # 这里的length = 4也就算是四维的吧        for x in range(len(trainingSet)):            #testinstance            dist = euclideanDistance(testInstance, trainingSet[x], length)            #testInstance和trainingSet[x]都是这样的形式[6.2, 3.4, 5.4, 2.3, 'Iris-virginica']            #dist指的是两个‘点’之间的距离            # print dist            distances.append((trainingSet[x], dist))            #distances就是下面的这种样子            #([6.7, 3.3, 5.7, 2.1, 'Iris-virginica'], 5.568662316930341),            #....            # ([6.2, 2.8, 4.8, 1.8, 'Iris-virginica'], 4.497777228809804)]            #distances.append(dist)        distances.sort(key=operator.itemgetter(1))#这里是根据距离来排序的        neighbors = []        for x in range(k):            neighbors.append(distances[x][0])        return neighbors    #对k个近邻进行合并,返回value最大的key    def getResponse(neighbors): #这个方法就是返回某个点预测的是哪一类        classVotes = {}        for x in range(len(neighbors)): #len(neighbors) 其实就是3            response = neighbors[x][-1]            if response in classVotes:                classVotes[response] += 1            else:                classVotes[response] = 1        # print classVotes        # classVotes是这种样子的:        # {'Iris-virginica': 2, 'Iris-versicolor': 1}        # {'Iris-virginica': 3}        sortedVotes = sorted(classVotes.iteritems(), key=operator.itemgetter(1), reverse=True)        return sortedVotes[0][0]    def getAccuracy(testSet, predictions):        correct = 0        for x in range(len(testSet)):            if testSet[x][-1] == predictions[x]:                correct += 1        return (correct/float(len(testSet)))*100.0    def main():        #prepare data        trainingSet = []#训练数据集        testSet = [] #测试数据集        split = 0.67#训练数据集和测试数据集分割的比例        loadDataset('C:\Users\zmj\Desktop\irisdata.txt', split, trainingSet, testSet)#split分割训练集和测试集        print 'Train set: ' + repr(len(trainingSet))        print 'Test set: ' + repr(len(testSet))        #generate predictions        predictions = []        k = 3        for x in range(len(testSet)):            # trainingsettrainingSet[x]            neighbors = getNeighbors(trainingSet, testSet[x], k)            #testSet[x]指的是测试集,每个testSet[x]是这样的[6.2, 3.4, 5.4, 2.3, 'Iris-virginica']            # neithbors 指的是返回K个实例,            # print neighbors            # neighbors是如下这种样子的            # [[6.4, 3.2, 5.3, 2.3, 'Iris-virginica'], [6.5, 3.2, 5.1, 2.0, 'Iris-virginica'],             # [6.5, 3.0, 5.2, 2.0, 'Iris-virginica']]            result = getResponse(neighbors)            predictions.append(result)            print ('>predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))        accuracy = getAccuracy(testSet, predictions)        print('Accuracy: ' + repr(accuracy) + '%')    if __name__ == '__main__':        main()

代码中出现很多Python排序的例如:[list].sort(),sorted([list])等,详情请看我另外一篇文章http://blog.csdn.net/zhongjunlang/article/details/78155118

阅读全文
1 0
原创粉丝点击