《机器学习实战》-kNN算法手写算法识别

来源:互联网 发布:nginx 性能优化 编辑:程序博客网 时间:2024/05/22 14:50

   通过观看机器学习实战这本书,有了些许读后感,下面是我理解这本书里面的KNN算法,希望阔以帮助你们稍微加强一下理解微笑数据集代码下载

KNN算法其实就是邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
也就是说离你最近的k个点中,大多数点属于的类别也就是这个样本属于的类别,也就是用俗语说,物以类聚,人以群分。属于较为简单的算法

#coding=UTF8from numpy import *import operatorfrom os import listdirdef classify0(inX, dataset, labels, k):    dataSetSize = dataset.shape[0]#训练集的行数    diffMat = tile(inX, (dataSetSize, 1)) - dataset#tile函数复制datasize行    sqDiffMat = diffMat ** 2#平方    sqDistance = sqDiffMat.sum(axis=1)#横向求和    distance = sqDistance ** 0.5#开方    sortedDistIndicies = distance.argsort()#排序    classCount = {}    for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def img2vector(filename):    returnVect = zeros((1,1024))#返回一行,1024列的数组    fr = open(filename)    for i in range(32):        lineStr = fr.readline()#一行一行读取        for j in range(32):            returnVect[0,32*i+j] = int(lineStr[j])#把文件中的数据存入到里面,键值对应    return returnVect#返回一个装着所有数据的列表def handwritingClassTest():    hwLabels = []    # 加载训练数据    trainingFileList = listdir('trainingDigits')    m = len(trainingFileList)#文件的个数    trainingMat = zeros((m,1024))#建立一个m行,1024列的数组,每行存入一个文件    for i in range(m):        fileNameStr = trainingFileList[i]#第 i个文件        fileStr = fileNameStr.split('.')[0]     #分割文件名字        classNumStr = int(fileStr.split('_')[0])        hwLabels.append(classNumStr)#添加标签,也就是数字是几,添加到一个列表里面        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#把每一个文件一行一行的存入到训练集列表里面    testFileList = listdir('testDigits')        #iterate through the test set    errorCount = 0.0    mTest = len(testFileList)    for i in range(mTest):        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]     #take off .txt        classNumStr = int(fileStr.split('_')[0])        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)#只有一行的列表        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)#计算测试集和每个训练集的距离        print "the classifier came back with: %d, the real answer is: %d, The predict result is: %s" % (classifierResult, classNumStr, classifierResult==classNumStr)        if (classifierResult != classNumStr): errorCount += 1.0    print "\nthe total number of errors is: %d / %d" %(errorCount, mTest)    print "\nthe total error rate is: %f" % (errorCount/float(mTest))if __name__== "__main__":    handwritingClassTest()

0 0