机器学习算法之KNN识别mnist数据集

来源:互联网 发布:效果图设计软件 编辑:程序博客网 时间:2024/06/08 12:08

KNN算法又称K邻近算法(K Nearest Neighbor),其基本思想为:样本空间中,某样本的类别为距离其最近的k个邻居所属类别中最多的那个类别。

MNIST数据集为一个带标注的手写数字识别的数据集,其官方下载地址为http://yann.lecun.com/exdb/mnist/。数据集包含60000个训练集和10000个测试集。该数据集中的文件以二进制方式保存,因此在读入时需要以二进制方式打开。

在学习读入MNIST数据集过程中,最大的收获是对Python中struct包的简单的了解。其中,对MNIST数据集的读入,主要参考博客http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html。个人觉得这一篇讲的还是很清楚的。

至于算法实现,有两个收获:

1. 算法优化:初始算法步骤比较冗长,肆无忌惮的使用for循环。这种做法对于MNIST这种数据量相对较大的数据集还是很致命的。改进后的算法,其运行时间明显减少(虽然还是很大安静)。

2. Numpy库中的argsort(list):以前对列表进行排序后想要获取对应的列表的索引,各种index,而且容易出现值重合的现象。但是,argsort(list)真的是神器。它对列表的值进行排序(降序)后,返回的列表值对应的索引。有木有很方便!

下面就是代码啦(Python):

读入图片,转化为一个数组(每一行为对应图片的像素):

# read the image file# input: file path#output: the list of piexl array for each imagedef read_image(file_path):    f_open=open(file_path,"rb")    content=f_open.read()    index=0    magic, num_images,num_rows,num_columns=struct.unpack_from(">IIII",content,index) # 以大端法读入四个unsigned int    print("number of images:"+str(num_images))    print("number of rows:"+str(num_rows))    print("number of columns:"+str(num_columns))    index+=struct.calcsize(">IIII")    img_piexl=[]    for i in range(num_images):        piexl_all=[]        for j in range(num_columns):            for k in range(num_rows):                piexl=struct.unpack_from(">B",content,index)                ### This way using the original piexl                #piexl=int(piexl[0])                ### This way using the processed piexl                ### if piexl<127 then it becomes 0                ### else it becomes 1                piexl=int(piexl[0])                if piexl<127:                    piexl=0                else:                    piexl=1                ###                piexl_all.append(piexl)                index+=struct.calcsize(">B")        piexl_all=np.array(piexl_all)        img_piexl.append(piexl_all)        """        print(piexl_all)        piexl_all=piexl_all.reshape(28,28)        fig=plt.figure()        plotwindow=fig.add_subplot(111)        plotwindow.imshow(piexl_all,cmap="gray")        plt.show()        """        if i%1000==0:            print(str(i)+"images have been processed")    f_open.close()    return img_piexl


读入标签,同样转化为一个数组(每一行为对应图片的标签):

# read label file# input: file path# output: list of labelsdef read_label(file_path):    f_open=open(file_path,"rb")    content=f_open.read()    index=0    magic, num_items=struct.unpack_from(">II",content,index)    print("number of labels:"+str(num_items))    index+=struct.calcsize(">II")    label_num=[]    for i in range(num_items):        label=struct.unpack_from(">B",content,index)        label_num.append(int(label[0]))        index+=struct.calcsize(">B")        if i%1000==0:            print(str(i)+"labels have been processed!")    f_open.close()    return label_num

计算两个图片像素之间的距离:

# for each train image and test image, calculating their distance# input: list of piexls for train image, list of piexls for test image# output: the distance between train image and test imagedef calc_dis(train_image,test_image):    dist=np.linalg.norm(train_image-test_image)    return dist

对于单个图片,寻找其最邻近的k个邻居对应的标签以及标签数目的统计:

# find labels for test image# input: the number of neighbors, the list of training images the list of training labels, the test image# output: the dictionary whose key is label and value is its corresponding appearing timedef find_labels(k,train_images,train_labels,test_image):    all_dis = []    labels=defaultdict(int)    for i in range(len(train_images)):        dis = np.linalg.norm(train_images[i]-test_image)        all_dis.append(dis)    sorted_dis = np.argsort(all_dis)    count = 0    while (count < k):        labels[train_labels[sorted_dis[count]]]+=1        count += 1    return labels

KNN算法:

# for all test images, finding its labels by knn# input: number of neighbors, list of train images, list of train labels, list of test images# output: result of labels for each imagedef knn_all(k,train_images,train_labels,test_images):    print("start knn_all!")    res=[]    count=0    for i in range(len(test_images)):        labels=find_labels(k,train_images,train_labels,test_images[i])        res.append(max(labels))        if count%1000==0:            print("%d has been processed!"%(count))        count+=1    return res

计算分类准确率(accuracy):

# calculate the precision of knn result# input: the list of result labels, the list of test labels# output: the precision of label resultsdef calc_precision(res,test_labels):    f_res_open=open("res.txt","a+")    precision=0    for i in range(len(res)):        f_res_open.write("res:"+str(res[i])+"\n")        f_res_open.write("test:"+str(test_labels[i])+"\n")        if res[i]==test_labels[i]:            precision+=1    return precision/len(res)


主程序:

image_train_file_path="E:/mnist/train-images-idx3-ubyte/train-images.idx3-ubyte"label_train_file_path="E:/mnist/train-labels-idx1-ubyte/train-labels.idx1-ubyte"image_test_file_path="E:/mnist/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"label_test_file_path="E:/mnist/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"image_train_piexl=read_image(image_train_file_path)label_train=read_label(label_train_file_path)image_test_piexl=read_image(image_test_file_path)label_test=read_label(label_test_file_path)print("reading all files completed!")k=5start_time=time.clock()res=knn_all(k,image_train_piexl,label_train,image_test_piexl)end_time=time.clock()print("precision:"+str(calc_precision(res,label_test)))print("running time"+str(end_time-start_time))


实验结果截图:

k=1:


k=5:


k=10:


k越大,其可能带来的噪声越多,准确率越低。


关于该算法,还有一种实现,就是使用sklearn(Python库)

关于这个库中的knn分类器,http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html这个网址讲的比较清楚

代码的话,自己只修改了主程序,精确度还是很高的。

主程序:

image_train_file_path="E:/mnist/train-images-idx3-ubyte/train-images.idx3-ubyte"label_train_file_path="E:/mnist/train-labels-idx1-ubyte/train-labels.idx1-ubyte"image_test_file_path="E:/mnist/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"label_test_file_path="E:/mnist/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"image_train_piexl=read_image(image_train_file_path)label_train=read_label(label_train_file_path)image_test_piexl=read_image(image_test_file_path)label_test=read_label(label_test_file_path)print("reading all files completed!")# another way of knn# using sklearnstart_time=time.clock()model=KNeighborsClassifier(n_neighbors=5)model.fit(image_train_piexl,label_train)"""res=model.predict(image_test_piexl)# predict the resultend_time=time.clock()print("the precision is: %f"%(calc_precision(res,label_test)))"""print("the precision calculated by sklearn is: %f"%(model.score(image_test_piexl,label_test))) # output the precision of the result directlyend_time=time.clock()print("running time is: %d"%(end_time-start_time))


最后,关于精确度(precision),准确率(accuracy)和召回率(recall),这个知乎的回答解释的很清楚https://www.zhihu.com/question/19645541。希望有所帮助。

最后的最后,虽然不尽如人意,但是还是要好好做自己,努力走向自己喜欢的地方不是吗?国庆节快乐,中秋节要花好月圆人长久(终于体会到这句话多美好了~)



原创粉丝点击