KNN(k近邻算法)

来源:互联网 发布:js如何获取鼠标位置 编辑:程序博客网 时间:2024/05/11 03:11

KNN算法:

  • 优点:精确度高,有离群点效果稳定,不需要关于数据的任何假设
  • 缺点:需要大量的计算,需要许多内存
  • 适用范围:目标属性数值型和标称型

工作过程:

现在假如我们已经拥有了一些数据(称为训练数据集)TS,并且拥有所有数据的类别名---即每条数据应该归于哪个类别。当我们要判断一条不知类别的数据时,首先让这条数据M和已经拥有的所有的数据TS中的每一条数据进行比较,然后根据比较结果选择出和M最相似(一般是基于距离)的K条数据(K是个整数,并且通常是小于20),最后,然后根据他们的主要分类来决定新数据的类别

K近邻算法的一般流程:

  1. 收集数据,可以使用任何方法
  2. 准备数据,距离计算所需要的数值,最好是结构化的数据格式
  3. 分析数据:可以使用任何方法
  4. 训练算法:不需要
  5. 测试算法:计算错误率
  6. 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理
实例:机器学习实战中约会对象判断例子的代码
from numpy import *
import operator
from os import listdir


def createDateSet():
    group=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels=['A','A','B','B']
    return group,labels
#----------------------------------------------------------------------
def classify0(inX, dataSet, labels, k):
    """
    simple classifier
    """
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


#----------------------------------------------------------------------
def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return   
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector




#----------------------------------------------------------------------
def autoNorm(dataSet): 
    """ normalizing numeric values
    """
    minVals=dataSet.min(0)
    maxVals=dataSet.max(0) 
    ranges=maxVals-minVals
    normDataSet=zeros(shape(dataSet) )
    m=dataSet.shape[0]
    normDataSet=dataSet-tile(minVals, (m,1))
    normDataSet=normDataSet/tile(ranges, (m,1))
    return normDataSet, ranges, minVals


#----------------------------------------------------------------------
def datingClassyTest():
    """
    test the classifier
    """
    hoRaio=0.05
    datingDataMata,datingLabel=file2matrix('datingTestSet2.txt')
    normMat,ranges,minVals=autoNorm(datingDataMata)
    m=normMat.shape[0]
    numTestVecs=int(m*hoRaio)
    errorCount=0.0
    for i in range(numTestVecs):
        classifierResult=classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabel[numTestVecs:m], 3)
        print "the classifier came back with: %d ,the real number is : %d" %(classifierResult,datingLabel[i])
        if(classifierResult!=datingLabel[i]):
            errorCount+=1.0
    print 'the total error rate is : %f' % (errorCount/float(numTestVecs))
    
    

    


0 0
原创粉丝点击