Kaggle in Practice: Recognizing Handwritten Digits [Improved kNN Algorithm]

Source: Internet  Editor: 程序博客网  Time: 2024/06/07 12:37

Notes

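  • The model does not use sklearn's built-in kNN algorithm (which scored 96.800% at the time).
  • It improves on an expert's code (96.400%), raising the score to 96.886%.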

Code

from numpy import *
import operator
import csv


def toInt(array):
    # Convert a matrix of numeric strings into an array of integer values.
    array = mat(array)
    m, n = shape(array)
    newArray = zeros((m, n))
    for i in range(m):
        for j in range(n):
            newArray[i, j] = int(array[i, j])
    return newArray


def loadTrainData():
    l = []
    with open('train.csv') as file:
        lines = csv.reader(file)
        for line in lines:
            l.append(line)  # 42001*785
    l.remove(l[0])  # drop the header row
    l = array(l)
    label = l[:, 0]
    data = l[:, 1:]
    return toInt(data), toInt(label)  # label 1*42000  data 42000*784
    # return data, label


def loadTestData():
    l = []
    with open('test.csv') as file:
        lines = csv.reader(file)
        for line in lines:
            l.append(line)  # 28001*784
    l.remove(l[0])  # drop the header row
    data = array(l)
    return toInt(data)  # data 28000*784


# dataSet: m*n   labels: m*1   inX: 1*n
def classify(inX, dataSet, labels, k):
    inX = mat(inX)
    dataSet = mat(dataSet)
    labels = mat(labels)
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every row of dataSet.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = array(diffMat) ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    # Majority vote among the k nearest training samples.
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i], 0]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def saveResult(result):
    # Write one predicted label per row.
    with open('result.csv', 'w', newline='') as myFile:
        myWriter = csv.writer(myFile)
        for i in result:
            tmp = []
            tmp.append(i)
            myWriter.writerow(tmp)


def handwritingClassTest():
    trainData, trainLabel = loadTrainData()
    testData = loadTestData()
    m, n = shape(testData)
    resultList = []
    for i in range(m):
        classifierResult = classify(testData[i], trainData, trainLabel.transpose(), 5)
        resultList.append(classifierResult)
    saveResult(resultList)


handwritingClassTest()
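
To see the distance-and-vote logic of `classify` in isolation, here is a minimal sketch on a toy 2-D dataset. It is not the author's code: the function name `knn_predict` and the toy points are assumptions made up for illustration, and it is written for Python 3 with only the standard library rather than numpy.

```python
import math
from collections import Counter


def knn_predict(query, points, labels, k):
    """Return the majority label among the k training points nearest to query."""
    # Euclidean distance from the query to every training point.
    distances = [
        math.sqrt(sum((a - b) ** 2 for a, b in zip(query, p))) for p in points
    ]
    # Indices of the training points, nearest first.
    order = sorted(range(len(points)), key=lambda i: distances[i])
    # Count the labels of the k nearest neighbours.
    votes = Counter(labels[i] for i in order[:k])
    # most_common(1) returns [(label, count)] for the most frequent label.
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two well-separated clusters labelled 0 and 1.
toy_points = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (1.0, 1.0), (0.9, 1.1), (1.1, 0.9)]
toy_labels = [0, 0, 0, 1, 1, 1]

near_origin = knn_predict((0.15, 0.15), toy_points, toy_labels, k=3)
near_one = knn_predict((0.95, 1.0), toy_points, toy_labels, k=3)
print(near_origin, near_one)  # prints: 0 1
```

With the real data, each point would be a 784-pixel row and each label a digit; the voting step is the same.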

July 12

  • Source code unchanged; setting k to 3 raised the accuracy to 96.929%.
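
The post does not say how k = 3 was chosen. One common way to compare values of k without spending Kaggle submissions is to hold out part of the labelled training data and measure accuracy on it. The sketch below illustrates that idea on made-up toy data; `knn_predict`, `holdout_accuracy`, and all the points are hypothetical names and values, not taken from the original code, and the toy scores say nothing about the real dataset.

```python
import math
from collections import Counter


def knn_predict(query, points, labels, k):
    """Majority vote among the k nearest training points (Euclidean distance)."""
    order = sorted(
        range(len(points)),
        key=lambda i: math.sqrt(sum((a - b) ** 2 for a, b in zip(query, points[i]))),
    )
    votes = Counter(labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]


def holdout_accuracy(train_x, train_y, val_x, val_y, k):
    """Fraction of held-out points whose predicted label matches the true label."""
    correct = sum(
        knn_predict(x, train_x, train_y, k) == y for x, y in zip(val_x, val_y)
    )
    return correct / len(val_x)


# Hypothetical toy data: two clusters, plus one deliberately mislabelled
# point at (0.15, 0.15) sitting inside the class-0 cluster.
train_x = [(0.0, 0.0), (0.2, 0.1), (0.1, 0.3), (0.15, 0.15),
           (1.0, 1.0), (0.8, 1.1), (1.1, 0.9)]
train_y = [0, 0, 0, 1, 1, 1, 1]

# Held-out points with known labels.
val_x = [(0.14, 0.16), (0.1, 0.1), (0.9, 1.0), (1.0, 0.8)]
val_y = [0, 0, 1, 1]

# Score each candidate k on the held-out set and keep the best one.
scores = {k: holdout_accuracy(train_x, train_y, val_x, val_y, k) for k in (1, 3, 5)}
best_k = max(scores, key=scores.get)
print(scores, best_k)
```

In this toy setup k = 1 is misled by the single mislabelled point, while larger k outvotes it. On the real data the best k has to be measured, as the score change from k = 5 to k = 3 suggests.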