《机器学习实战》：使用 k-近邻算法识别手写数字

来源:互联网 发布:怎么设置淘宝店铺红包 编辑:程序博客网 时间:2024/05/16 08:04

算法实现（模块 kNN_digits_L.py）：

# k-nearest-neighbours classifier plus the dating-site and handwritten-digit
# demos from "Machine Learning in Action", chapter 2 (Python 3).
import operator
import os

# Star import kept on purpose: callers use this module's namespace for numpy
# names (array, zeros, ...), so narrowing it would break them.
from numpy import *


def createDataSet():
    """Return a tiny toy data set for sanity-checking classify0.

    Returns:
        (group, labels): a 4x2 array of points and the list of their class
        labels -- 'A' for the two points near (1, 1), 'B' for the two near
        the origin.
    """
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Distance is Euclidean.

    Args:
        inX: query point; anything that broadcasts against one row of
            ``dataSet`` (a 1-D sequence of n_features values, or a
            1 x n_features array such as img2vector returns).
        dataSet: 2-D array with one training sample per row.
        labels: class labels aligned with the rows of ``dataSet``.
        k: number of neighbours that vote; values larger than the number of
            samples are clamped to that number.

    Returns:
        The label with the most votes. On a tie, the label whose first vote
        came from the nearer neighbour wins.

    Raises:
        ValueError: if ``dataSet`` has no rows or ``k`` < 1.
    """
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    if k > dataSetSize:
        k = dataSetSize  # previously an IndexError

    # Broadcasting subtracts the query from every row (same as tiling it).
    diffMat = asarray(inX, dtype=float) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    # Stable sort so equidistant samples keep their data-set order.
    sortedDistIndicies = distances.argsort(kind='mergesort')

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # dicts keep insertion (nearest-first) order and sorted() is stable even
    # with reverse=True, so ties resolve toward the nearer neighbour's label.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def file2matrix(filename, numFeatures=3):
    """Parse a tab-separated data file into a feature matrix and labels.

    Each non-blank line holds at least ``numFeatures`` numeric fields
    followed (last field) by an integer class label.

    Args:
        filename: path of the text file.
        numFeatures: how many leading columns to read as features.

    Returns:
        (returnMat, classLabelVector): an (n_lines, numFeatures) float array
        and the list of int labels, in file order.
    """
    with open(filename) as fr:
        # Skip blank lines such as a trailing newline, which used to crash
        # the int() conversion of the label.
        rows = [line.strip() for line in fr if line.strip()]

    returnMat = zeros((len(rows), numFeatures))
    classLabelVector = []
    for index, line in enumerate(rows):
        listFromLine = line.split('\t')
        returnMat[index, :] = [float(v) for v in listFromLine[0:numFeatures]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into [0, 1].

    Returns:
        (normDataSet, ranges, minVals), where
        ``normDataSet = (dataSet - minVals) / ranges``. A constant column has
        its range reported as 1 (not 0), so it scales to all zeros instead
        of NaN, and callers can reuse ``ranges`` to normalise new points.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    ranges = where(ranges == 0, 1, ranges)  # avoid 0/0 on constant columns
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals


def datingClassTest(hoRatio=0.10, filename="datingTestSet2.txt", k=3):
    """Hold-out evaluation of classify0 on the dating data set.

    The first ``int(n * hoRatio)`` rows are classified against the rest.
    Each prediction and the overall error rate are printed.

    Returns:
        The error rate as a float.

    Raises:
        ValueError: if ``hoRatio`` selects no test rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError("hoRatio selects no test vectors")

    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is: %f" % errorRate)
    return errorRate


def classifyPerson():
    """Prompt for three feature values and print the predicted dating class."""
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # The order differs from the prompt order. It must match the data file's
    # column order (assumed: miles, game-time %, ice cream -- confirm).
    inArr = array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges, normMat,
                                 datingLabels, 3)
    # Labels in the file are assumed to be 1..3, hence the -1 index shift.
    print("You will probably like this person: ",
          resultList[classifierResult - 1])


def img2vector(filename):
    """Read a 32x32 text image of '0'/'1' characters into a 1x1024 array.

    Pixel (row i, column j) is stored at index ``32 * i + j``.
    """
    returnVect = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _digitFiles(dirName):
    """Return the sorted '.txt' entries of dirName (skips e.g. .DS_Store)."""
    return sorted(name for name in os.listdir(dirName) if name.endswith('.txt'))


def _digitLabel(fileNameStr):
    """Parse the digit class from a file name such as '7_45.txt' -> 7."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest(k=3, trainingDir='trainingDigits', testDir='testDigits'):
    """Classify every test digit image against all training images.

    File names are expected to look like '<digit>_<n>.txt'; the leading digit
    is the true class. Each prediction, the error count and the error rate
    are printed.

    Returns:
        The error rate as a float.

    Raises:
        ValueError: if ``testDir`` contains no '.txt' files.
    """
    trainingFileList = _digitFiles(trainingDir)
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitLabel(fileNameStr))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = _digitFiles(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError("no .txt test files found in %r" % testDir)

    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitLabel(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is : %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    errorRate = errorCount / float(mTest)
    print('\nthe total number of errors is: %d' % errorCount)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate


运行算法：

# Demo: inspect one digit image as a vector, then run the full
# handwritten-digit evaluation.
import kNN_digits_L

testVector = kNN_digits_L.img2vector('testDigits/0_13.txt')
print(testVector)
# Image row i occupies indices 32*i .. 32*i+31 and slice ends are exclusive,
# so [0:32] is the complete first row and [32:64] the complete second row.
print(testVector[0, 0:32])
print(testVector[0, 32:64])

kNN_digits_L.handwritingClassTest()



阅读全文
0 0
原创粉丝点击