《Machine Learning in Action》 读书笔记之一:K近邻分类器算法

来源:互联网 发布:云计算的三种形式 编辑:程序博客网 时间:2024/05/22 02:28

1. 读取数据(将文本文件解析为特征矩阵和类别标签列表)

def file2matrix(filename):
    """Parse a tab-separated data file into a feature matrix and a label list.

    Every line is stripped and split on tab characters. The first three
    fields are converted to floats and stored as one row of the matrix; the
    last field is parsed as an integer class label.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(returnMat, classLabelVector)``: ``returnMat`` is an array
        with one row per line and three columns, and ``classLabelVector`` is
        a list holding one int label per line.

    Raises:
        ValueError: If a feature field is not a valid float or the last
            field of a line is not a valid integer.
    """
    # Read the file once and close it deterministically; the earlier version
    # opened the file twice and never closed either handle.
    with open(filename) as fr:
        lines = fr.readlines()

    returnMat = zeros((len(lines), 3))          # prepare matrix to return
    classLabelVector = []                       # prepare labels to return
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        # Convert explicitly instead of relying on implicit str -> float
        # casting during array assignment.
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector

2. 数据的显示

from numpy import *
import kNN
import matplotlib
import matplotlib.pyplot as plt

# Scatter-plot feature column 0 (x axis) against feature column 1 (y axis)
# of the data set loaded from 'data.txt'.
fig = plt.figure()
ax = fig.add_subplot(111)
datingDataMat, datingLabels = kNN.file2matrix('data.txt')

# Both the marker size and the marker colour are scaled by the class label,
# so points of different classes can be told apart.
ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
           15.0 * array(datingLabels), 15.0 * array(datingLabels))
ax.axis([-2, 100000, -0.2, 25])

# Alternative view used earlier: plot columns 1 and 2 with
# ax.axis([-2, 25, -0.2, 2.0]), x label 'Percentage of Time Spent Playing
# Video Games' and y label 'Liters of Ice Cream Consumed Per Week'.
#
# The labels below were left over from that view. Column 1 (the percentage
# feature) is now on the y axis, so its label moves there.
# NOTE(review): the x label assumes column 0 holds frequent flyer miles per
# year, as in the book's dating data set - confirm against data.txt.
plt.xlabel('Frequent Flyer Miles Earned Per Year')
plt.ylabel('Percentage of Time Spent Playing Video Games')
plt.show()


3. 数据的归一化

def autoNorm(dataSet):
    """Min-max normalise every column of ``dataSet``.

    Each value is mapped to ``(value - column_min) / column_range``, so a
    column's minimum becomes 0 and its maximum becomes 1.

    Args:
        dataSet: Array of samples; expected to be 2-D with one row per
            sample and one column per feature.

    Returns:
        A tuple ``(normDataSet, ranges, minVals)``: the normalised array, the
        per-column ``max - min`` values and the per-column minimums. The
        latter two let callers apply the same scaling to new samples.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals

    # A constant column has a range of 0, which used to produce NaN through
    # 0/0. Divide such columns by 1 instead so they normalise to 0. The
    # returned ``ranges`` is left untouched for callers.
    divisors = ranges.copy()
    divisors[divisors == 0] = 1

    # Broadcasting applies the per-column vectors to every row, which makes
    # the explicit tile() copies unnecessary.
    normDataSet = (dataSet - minVals) / divisors   # element wise divide
    return normDataSet, ranges, minVals


4. 用kNN分类器进行分类

def datingClassTest(hoRatio=0.10, filename='data.txt', k=3):
    """Estimate the classifier's error rate on a held-out slice of the data.

    The data file is loaded with ``file2matrix`` and normalised with
    ``autoNorm``. The first ``int(m * hoRatio)`` rows are used as test
    vectors; each one is passed to ``classify0`` together with the remaining
    rows and their labels. One line is printed per test vector, followed by
    the overall error rate and the raw error count.

    Args:
        hoRatio: Fraction of the rows held out for testing (default 10%).
        filename: Data file passed to ``file2matrix``.
        k: Last argument passed to ``classify0`` (default 3).
            NOTE(review): presumably the number of neighbours - confirm
            against classify0, which is defined elsewhere.

    Raises:
        ValueError: If the hold-out set would be empty, since no error rate
            can be computed then.
    """
    datingDataMat, datingLabels = file2matrix(filename)   # load data set from file
    normMat = autoNorm(datingDataMat)[0]   # ranges / minVals are not needed here
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        # Previously this surfaced as a bare ZeroDivisionError at the end.
        raise ValueError("hold-out set is empty: %d rows with hoRatio=%s"
                         % (m, hoRatio))

    errorCount = 0.0
    for i in range(numTestVecs):
        # Rows [0, numTestVecs) are test vectors; the remaining rows and
        # their labels are what classify0 compares against.
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                     datingLabels[numTestVecs:m], k)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)




0 0
原创粉丝点击