机器学习实战——k-近邻算法

来源:互联网 发布:java多线程经典书籍 编辑:程序博客网 时间:2024/04/30 02:23
实验数据:
# k-nearest-neighbours (kNN) classifier with two demos: a dating-site
# preference data set and 32x32 text-encoded handwritten digits.

from os import listdir
import operator
import os

import numpy as np


# Generate sample data
def createDataSet():
    """Return a tiny labelled sample set for trying out the classifier.

    Returns:
        tuple: ``(group, labels)`` where ``group`` is a 4x2 float array of
        points and ``labels`` is the list of their class labels in the same
        order.
    """
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


# kNN classifier
def mycalssify(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest samples.

    Args:
        inX: The point to classify (sequence or array broadcastable against
            the rows of ``dataSet``).
        dataSet: 2-D array with one training sample per row.
        labels: Sequence of class labels, one per row of ``dataSet``.
        k: Number of nearest neighbours that vote.

    Returns:
        The label with the most votes. On a tie, the label that was seen
        first while walking the neighbours from nearest to farthest wins.

    Raises:
        IndexError: If ``k`` is smaller than 1 or larger than the number of
            samples.
    """
    # Broadcasting subtracts inX from every row, replacing the explicit tile().
    diffMat = np.asarray(inX) - dataSet
    distance = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDisIndicies = distance.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDisIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # sorted() is stable even with reverse=True, so ties keep insertion order.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# Convert a text file to a numpy matrix (still an ndarray)
def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and label list.

    Each non-blank line supplies three numeric features in its first three
    fields and an integer class label in its last field.

    Args:
        filename: Path of the text file to read.

    Returns:
        tuple: ``(returnMat, classLabelVector)`` where ``returnMat`` is an
        ``(n, 3)`` float array and ``classLabelVector`` is a list of ``n``
        ints.

    Raises:
        ValueError: If a line has fewer than three fields or a field cannot
            be parsed as a number.
    """
    features = []
    classLabelVector = []
    # The context manager closes the file even if parsing fails.
    with open(filename) as fr:
        for lineNumber, line in enumerate(fr, start=1):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            listFromLine = line.split('\t')
            if len(listFromLine) < 3:
                raise ValueError('%s:%d: expected at least 3 tab-separated '
                                 'fields' % (filename, lineNumber))
            features.append([float(value) for value in listFromLine[0:3]])
            classLabelVector.append(int(listFromLine[-1]))

    if features:
        returnMat = np.array(features, dtype=float)
    else:
        returnMat = np.zeros((0, 3))
    return returnMat, classLabelVector


def _safeRanges(ranges):
    """Return ``ranges`` with zeros replaced by 1 so it is safe to divide by."""
    return np.where(ranges == 0, 1.0, ranges)


# Feature normalisation
def autoNorm(dataSet):
    """Rescale every column of ``dataSet`` to the interval [0, 1].

    A column whose values are all equal has a range of zero; it is mapped to
    all zeros instead of producing NaN/inf from a division by zero.

    Args:
        dataSet: 2-D array with one sample per row.

    Returns:
        tuple: ``(normDataSet, ranges, minVal)`` - the normalised data, the
        per-column ``max - min`` and the per-column minimum.
    """
    minVal = dataSet.min(0)
    maxVal = dataSet.max(0)
    ranges = maxVal - minVal
    normDataSet = (dataSet - minVal) / _safeRanges(ranges)
    return normDataSet, ranges, minVal


# Hold-out test on the dating-site data
def datingClassTest(hoRatio=0.10, k=4, filename='datingTestSet2.txt'):
    """Measure the classifier's error rate on the dating data set.

    The first ``hoRatio`` fraction of the samples is classified against the
    remaining samples; every prediction and the final error rate are printed.

    Args:
        hoRatio: Fraction of samples held out for testing.
        k: Number of neighbours used by the classifier.
        filename: Tab-separated data file to load.

    Returns:
        float: The error rate (0.0 when no sample is held out).
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, _ranges, _minVal = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = mycalssify(normMat[i, :],
                                      normMat[numTestVecs:m, :],
                                      datingLabels[numTestVecs:m], k)
        print('the myclass came back with:%d,the real answer is %d'
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    # Avoid ZeroDivisionError when the data set is too small to hold any out.
    errorRate = errorCount / float(numTestVecs) if numTestVecs else 0.0
    print('the total error rate is:%f' % errorRate)
    return errorRate


# Interactive prediction
def classifyPerson():
    """Prompt for three features and print the predicted preference level."""
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input('playing video game:'))
    ffMiles = float(input('miles:'))
    iceCream = float(input('icecream:'))
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    inArr = np.array([percentTats, ffMiles, iceCream])
    # Normalise the query the same way autoNorm normalised the training data.
    result = mycalssify((inArr - minVals) / _safeRanges(ranges),
                        normMat, datingLabels, 3)
    print('you will probably like this person:', resultList[result - 1])


# Convert a text-encoded image to a vector
def img2vector(filename):
    """Read a 32x32 grid of digit characters into a 1x1024 row vector.

    Args:
        filename: Path of a text file whose first 32 lines each start with
            32 digit characters.

    Returns:
        numpy.ndarray: Array of shape ``(1, 1024)`` in row-major order.
    """
    returnVect = np.zeros((1, 1024))
    # The context manager closes the file even if a line is malformed.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _digitLabel(fileNameStr):
    """Return the class part of a digit file name, e.g. '3' for '3_12.txt'."""
    return fileNameStr.split('.')[0].split('_')[0]


# Handwritten-digit test
def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits',
                         k=3):
    """Measure the classifier's error rate on the handwritten-digit files.

    Every file in ``trainingDir`` becomes a training sample whose label is
    the part of its name before the first underscore; every file in
    ``testDir`` is then classified and compared with its own label.

    Args:
        trainingDir: Directory holding the training digit files.
        testDir: Directory holding the test digit files.
        k: Number of neighbours used by the classifier.

    Returns:
        float: The error rate (0.0 when ``testDir`` is empty).
    """
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    hwLabels = [_digitLabel(name) for name in trainingFileList]
    trainingMat = np.zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = listdir(testDir)
    mTest = len(testFileList)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitLabel(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifileResult = mycalssify(vectorUnderTest, trainingMat,
                                      hwLabels, k)
        print('the classifier came back with: %s,the real answei is:%s'
              % (classifileResult, classNumStr))
        if classifileResult != classNumStr:
            errorCount += 1.0
    print('\n the total number of error is: %d' % errorCount)
    # Avoid ZeroDivisionError when the test directory is empty.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print('\n the total error rate is : %f' % errorRate)
    return errorRate


# Plot with matplotlib
def _plotDatingData(filename='datingTestSet2.txt'):
    """Build a scatter plot of columns 1 and 2 of the dating data.

    Marker size and colour both scale with the class label.
    """
    # Imported lazily so the classifier functions work without matplotlib.
    import matplotlib.pyplot as plt

    datingDataMat, datingLabels = file2matrix(filename)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    scaledLabels = 15.0 * np.array(datingLabels)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
               scaledLabels, scaledLabels)
    return fig


if __name__ == '__main__':
    # The figure is built but not shown, as in the original script; call
    # matplotlib.pyplot.show() to display it, or enable a demo below.
    _plotDatingData()
    # datingClassTest()
    # classifyPerson()
    # handwritingClassTest()
原创粉丝点击