A Simple KNN Implementation in Python

Source: Internet | Editor: 程序博客网 | Time: 2024/05/21 07:47

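The dataset is MNIST handwritten-digit data that I downloaded myself. It consists of a train.csv training file, a test.csv test file, and a submission.csv file (which holds the labels for test.csv). The principle behind KNN is simple, so without further ado, here is the code: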
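The post skips the KNN principle as simple. In short: to classify a query vector, compute its distance to every training vector, take the k closest, and return the label that occurs most often among them. Here is a minimal sketch of that rule on made-up 2-D points; the function name `knn_classify` and the toy data are illustrative assumptions, not part of the original code.

```python
from collections import Counter

import numpy as np


def knn_classify(query, train, labels, k=3):
    """Label `query` by majority vote among its k nearest training rows."""
    # Euclidean distance from the query to every training sample
    distances = np.sqrt(((train - query) ** 2).sum(axis=1))
    # Indices of the k smallest distances, nearest first
    nearest = np.argsort(distances)[:k]
    # Most common label among those neighbours; on a tie, Counter keeps
    # the label first encountered, i.e. the one belonging to the closer neighbour
    return Counter(labels[i] for i in nearest).most_common(1)[0][0]


# Toy data (made up for illustration): two separated clusters labelled 0 and 1
train = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]], dtype=float)
labels = np.array([0, 0, 0, 1, 1, 1])

print(knn_classify(np.array([0.5, 0.5]), train, labels))  # 0
print(knn_classify(np.array([5.5, 5.5]), train, labels))  # 1
```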


# author: zhouchao
# date: 2017-12-07 11:13
# description: use KNN to recognize handwritten digits
import numpy as np


def load_train_data(path):
    # Each row of train.csv: the label in column 0, the pixel values after it
    train = np.loadtxt(path, delimiter=",", skiprows=0)
    vec = train[:, 1:]
    labels = train[:, 0].astype(int)
    return vec, labels


def predict(line, vec, labels, k=20):
    # Euclidean distance from the query row to every training row
    # (broadcasting does the job of tile(line, (numSamples, 1)))
    diff = vec - line
    distance = np.sqrt((diff ** 2).sum(axis=1))
    # Training-row indices sorted by distance, nearest first
    sortedDistIndices = np.argsort(distance)
    # Count the labels of the k nearest neighbours
    classCount = {}
    for i in sortedDistIndices[:k]:
        voteLabel = labels[i]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    # Return the label with the most votes
    return max(classCount, key=classCount.get)


if __name__ == "__main__":
    vec, labels = load_train_data("../../data/handwrite/train.csv")
    with open("../../data/handwrite/test.txt") as f:
        for line in f:
            line = line.strip()
            if not line:  # skip blank lines
                continue
            nums = [int(x) for x in line.split(",")]
            matrix = np.array(nums)
            print(predict(matrix, vec, labels))
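The script only prints predictions; it never compares them with the labels in submission.csv, and it scans every training row once per test line. Below is a sketch, under stated assumptions, of a batched variant plus an accuracy check. `predict_batch`, `accuracy` and the `chunk` size are hypothetical names. The distance uses the expansion ||a − b||² = ||a||² − 2a·b + ||b||², a swapped-in technique that the original code does not use, and ties are broken toward the smallest label by `np.bincount(...).argmax()`, which can differ from the original's tie rule.

```python
import numpy as np


def predict_batch(test, train, labels, k=20, chunk=500):
    """Predict labels for every row of `test` by k-NN majority vote (hypothetical helper)."""
    # bincount below needs non-negative integer labels (true for digits 0-9)
    labels = np.asarray(labels).astype(int)
    # Squared norms of the training rows, computed once
    train_sq = (train ** 2).sum(axis=1)
    preds = []
    for start in range(0, len(test), chunk):
        block = test[start:start + chunk]
        # Squared distances for a whole block: ||a||^2 - 2 a.b + ||b||^2
        d2 = (block ** 2).sum(axis=1)[:, None] - 2 * block @ train.T + train_sq[None, :]
        # First k columns hold the indices of the k nearest rows (unordered)
        nearest = np.argpartition(d2, k - 1, axis=1)[:, :k]
        for row in labels[nearest]:
            # Majority vote; ties go to the smallest label
            preds.append(np.bincount(row).argmax())
    return np.array(preds)


def accuracy(pred, truth):
    """Fraction of predictions equal to the reference labels."""
    return float(np.mean(np.asarray(pred) == np.asarray(truth)))
```

If submission.csv holds the true labels in the same order as the test rows (the post does not show its layout, so this is an assumption), they could be loaded with `np.loadtxt` and passed to `accuracy` together with the output of `predict_batch`.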



