KNN（K 近邻，k-nearest neighbors）

来源:互联网 发布:大乐透软件破解预测 编辑:程序博客网 时间:2024/05/16 18:14

缺点:

耗内存：需要存储所有训练样本；时间成本高：对每个测试样本都要计算它与所有训练样本之间的距离。

kNN 与局部加权线性回归（Locally Weighted Linear Regression）在思想上非常相似：对每个预测点都需要单独训练模型。

代码如下:

from numpy import *
import os


def loaddata():
    """Return a tiny toy dataset for demonstrating ``knn``.

    Returns
    -------
    tuple
        ``(data, labels)``: a 4x2 float array of points and a list of the
        four matching class labels ('A' for the two points near (1, 1),
        'B' for the two near the origin).
    """
    points = [
        [1.0, 0.9],
        [1.0, 1.0],
        [0.1, 0.2],
        [0.0, 0.1],
    ]
    classes = ['A', 'A', 'B', 'B']
    return array(points), classes


def img2vect(filename, rows=32, cols=32):
    """Read a text "image" of digit characters into a flat row vector.

    The file must contain at least ``rows`` lines. The first ``cols``
    characters of each line must be single decimal digits. Pixel ``(i, j)``
    is stored at column ``i * cols + j`` (row-major order).

    Parameters
    ----------
    filename : str
        Path of the text file to read.
    rows, cols : int
        Image height and width in characters (default 32 x 32).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(1, rows * cols)``.

    Raises
    ------
    IndexError
        If the file has fewer than ``rows`` lines or a line is shorter than
        ``cols`` characters.
    ValueError
        If one of the read characters is not a digit.
    """
    imgvect = zeros((1, rows * cols))
    # ``with`` guarantees the file is closed even if parsing fails.
    with open(filename) as f:
        for i in range(rows):
            line = f.readline()
            for j in range(cols):
                # Row-major offset: the stride is the row width (cols), not
                # the row count; the two only coincide for square images.
                imgvect[0, i * cols + j] = int(line[j])
    return imgvect


def _load_digit_dir(dirpath):
    """Load every digit image file in *dirpath* into a feature matrix.

    Each file is parsed with ``img2vect`` into one 1024-element row. The
    class label is the integer prefix of the file name before the first
    underscore (e.g. a name like ``"7_12.txt"`` gives label ``7``).

    Returns ``(x, y)``: ``x`` has shape ``(num_files, 1024)`` and ``y`` is
    the list of int labels in the same row order.
    """
    # os.listdir order is arbitrary; sort so sample order is reproducible.
    filenames = sorted(os.listdir(dirpath))
    x = zeros((len(filenames), 1024))
    y = []
    for row, filename in enumerate(filenames):
        x[row, :] = img2vect(os.path.join(dirpath, filename))
        y.append(int(filename.split('_')[0]))
    return x, y


def loaddigit(filedir='C:/Users/yourname/Desktop/machin/digit/'):
    """Load the handwritten-digit training and test sets.

    Parameters
    ----------
    filedir : str
        Directory containing the ``trainingDigits`` and ``testDigits``
        subdirectories. Defaults to the path this script originally
        hard-coded.

    Returns
    -------
    tuple
        ``(train_x, train_y, test_x, test_y)``: feature matrices of shape
        ``(num_samples, 1024)`` and the matching lists of int labels.
    """
    print('get training sets:')
    train_x, train_y = _load_digit_dir(os.path.join(filedir, 'trainingDigits'))

    print("---Getting testing set...")
    test_x, test_y = _load_digit_dir(os.path.join(filedir, 'testDigits'))

    return train_x, train_y, test_x, test_y


def knn(dataset, testdata, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distance is Euclidean.

    Parameters
    ----------
    dataset : array-like, shape (n, d)
        Training samples, one per row.
    testdata : array-like, shape (d,) or (1, d)
        The sample to classify.
    labels : sequence
        Class label for each row of ``dataset``.
    k : int
        Number of neighbours that vote; must satisfy ``1 <= k <= n``.

    Returns
    -------
    object
        The label with the most votes. On a tie, the winner is the label
        whose first vote came from the nearest neighbour.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, n]``.
    """
    dataset = asarray(dataset)
    n = dataset.shape[0]
    if not 1 <= k <= n:
        raise ValueError('k must be between 1 and %d, got %r' % (n, k))
    # Broadcasting subtracts testdata from every row; no need to tile it.
    # Use the ndarray .sum method: the module's star import shadows the
    # builtin ``sum`` with numpy's, so the method form is unambiguous.
    dist = sqrt(((asarray(testdata) - dataset) ** 2).sum(axis=1))
    # Stable sort: equidistant neighbours are ranked by training index.
    nearest = argsort(dist, kind='stable')[:k]
    votes = {}
    for idx in nearest:
        label = labels[idx]
        votes[label] = votes.get(label, 0) + 1
    # max() returns the first maximal key in insertion order, i.e. the label
    # seen earliest among the nearest neighbours.
    return max(votes, key=votes.get)


def testHandWritingClass(k=3):
    """Evaluate kNN on the handwritten-digit test set and report accuracy.

    Loads the data with ``loaddigit()``. Each test sample is classified
    against the full training set with ``knn`` using ``k`` neighbours. The
    accuracy is printed as a percentage and returned as a fraction in
    [0, 1].

    Parameters
    ----------
    k : int
        Number of neighbours passed to ``knn`` (default 3).

    Returns
    -------
    float
        Fraction of test samples classified correctly.

    Raises
    ------
    ValueError
        If the test set is empty, because accuracy would be undefined.
    """
    train_x, train_y, test_x, test_y = loaddigit()
    numtestsamples = test_x.shape[0]
    if numtestsamples == 0:
        raise ValueError('no test samples found; cannot compute accuracy')
    # Explicit counter instead of sum(...): the module's star import makes
    # ``sum`` refer to numpy.sum, which does not accept generators cleanly.
    matchcount = 0
    for sample, expected in zip(test_x, test_y):
        if knn(train_x, sample, train_y, k) == expected:
            matchcount += 1
    accuracy = float(matchcount) / numtestsamples
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))
    return accuracy


if __name__ == '__main__':
    # Guarded so that importing this module does not run the demos (the
    # digit demo reads files from disk and fails where they are absent).

    # Demo 1: classify a single 2-D point against the toy dataset.
    data, labels = loaddata()
    testdata = [0.8, 0.7]
    ke = knn(data, testdata, labels, 2)
    print(ke)

    # Demo 2: full handwritten-digit evaluation; needs the digit data
    # directory that loaddigit() reads to exist on this machine.
    testHandWritingClass()

0 0
原创粉丝点击