KNN Python code

来源:互联网 发布:淘宝打不开怎么办 编辑:程序博客网 时间:2024/04/25 15:57

花几分钟写了个 KNN 的 Python 代码,在解释器上可以直接运行:


"""K-nearest-neighbours (KNN) classifier.

Algorithm:
1. Calculate the distance between the test point and every training point.
2. Sort the distances.
3. Select the k training points with the smallest distance.
4. Count the label frequency among those k points.
5. Return the label with the highest frequency.
"""
import numpy as np


class knn_csy(object):
    """Brute-force KNN classifier using Euclidean distance.

    ``dataset`` is a sequence of equally long numeric feature vectors (a list
    of lists or a 2-D NumPy array) and ``label`` holds one label per vector.
    Labels are counted with ``np.bincount``, so they must be non-negative
    integers.
    """

    def __init__(self, dataset, label):
        """Store the training vectors and their labels (no copy is made)."""
        self.dataset = dataset
        self.label = label

    def distance(self, dataset_i, testdata):
        """Return the Euclidean distance between two feature vectors.

        Both arguments are converted to arrays first so that plain Python
        lists can be subtracted element-wise.
        """
        diff = np.asarray(dataset_i, dtype=float) - np.asarray(testdata, dtype=float)
        return np.sqrt(np.sum(diff ** 2))

    @staticmethod
    def _appended(container, item):
        """Return ``container`` with ``item`` added as its last element.

        NumPy arrays have no ``append`` method, so a new array is built for
        them; any other container is extended in place via ``append``.
        """
        if isinstance(container, np.ndarray):
            return np.concatenate([container, np.asarray([item])])
        container.append(item)
        return container

    def calculate_dis(self, testdata, k=10, updateflage=0):
        """Predict the label of ``testdata`` by majority vote of its neighbours.

        :param testdata: feature vector to classify; must have the same
            length as the training vectors.
        :param k: number of nearest neighbours that vote (default 10). Values
            larger than the training set are reduced to the training-set size.
        :param updateflage: when truthy, ``testdata`` and its predicted label
            are added to the training data after classification.
        :return: the most frequent label among the k nearest neighbours; on a
            tied vote the smallest label wins (``np.argmax`` behaviour).
        :raises ValueError: if the training set is empty, ``testdata`` has the
            wrong length, or ``k`` is smaller than 1.
        """
        n_samples = len(self.dataset)
        if n_samples == 0:
            raise ValueError("cannot classify with an empty training set")
        if len(testdata) != len(self.dataset[0]):
            raise ValueError("wrong input array of testdata")
        if k < 1:
            raise ValueError("k must be a positive integer")
        # Never ask for more neighbours than there are training points.
        k = min(k, n_samples)

        dis = [self.distance(row, testdata) for row in self.dataset]
        # Sorting (distance, label) pairs orders by distance first; equal
        # distances fall back to comparing the labels.
        nearest = sorted(zip(dis, self.label))[:k]
        votes = [neighbour_label for _, neighbour_label in nearest]
        label = np.argmax(np.bincount(votes))

        if updateflage:
            self.dataset = self._appended(self.dataset, testdata)
            self.label = self._appended(self.label, label)
        return label


if __name__ == '__main__':
    # Imported here because only this demo needs mlxtend; the classifier
    # itself depends on NumPy alone.
    from mlxtend.data import iris_data

    dataset, label = iris_data()
    myknn = knn_csy(dataset, label)
    testdata = [2, 1, 1, 2]
    label = myknn.calculate_dis(testdata, 3)
    print(label)


原创粉丝点击