【Kaggle练习赛】之Digit Recognizer

来源:互联网 发布:传世1.0源码 编辑:程序博客网 时间:2024/04/29 16:45

Kaggle是国外的一项数据挖掘赛事,近期阿里并没有开办赛事,所以准备先拿Kaggle的练习赛来热热身,顺便学习一下scikit-learn这个开源库的使用。Kaggle入门可以参见 http://blog.csdn.net/u012162613/article/details/41929171

一.问题描述

The goal in this competition is to take an image of a handwritten single digit, and determine what that digit is. As the competition progresses, we will release tutorials which explain different machine learning algorithms and help you to get started.

The data for this competition were taken from the MNIST dataset. The MNIST (“Modified National Institute of Standards and Technology”) dataset is a classic within the Machine Learning community that has been extensively studied. More detail about the dataset, including Machine Learning algorithms that have been tried on it and their levels of success, can be found at http://yann.lecun.com/exdb/mnist/index.html.

本题的要求是识别手写字符,每个手写字符是一个28*28的灰度图像,从0-9。训练集里面是有42000个样本,测试集里是有28000个样本,每个测试样本也都是28*28的向量。我们的任务是对测试样本进行分类,判断每个测试样本属于0-9中的哪个字符。

二.解题思路

这道题实际上也可以视为模式识别里面的问题。现有的最好的解决方案是利用深度学习模型(SAE,CNN)来提取特征,然后训练SVM分类器进行分类。出于练习的目的,我并没有这样做,而是直接调用scikit-learn d 的KNN算法来解决。我的测试程序如下:

import csvfrom numpy import *import matplotlib.pyplot as pltfrom matplotlib.colors import ListedColormapfrom sklearn import neighbors, datasetsdef top_test():    train_data,train_label =loadTrainData()#res=digit.knn_classify(train_data,train_label)    test_data = load_test_data()    result = knn_classify(train_data,train_label,test_data)    gen_res_file(result)def loadTrainData():      file = open('train.csv','rb')    lines=csv.reader(file)      l=[];train_label=[];train_data=[]      for line in lines:         l.append(line)     l.remove(l[0])    for line in l:        train_label.append(int(line[0]))        tmp=[]        tmp=line[1:];curr=[]        for chrac in tmp:            if(int(chrac)>0):                tmp_num=1            else:                tmp_num=0            curr.append(tmp_num)        train_data.append(curr)    return train_data,train_labeldef load_test_data():    file = open('test.csv','rb')    lines=csv.reader(file)      l=[];test_data=[]      for line in lines:         l.append(line)     l.remove(l[0])    for line in l:                tmp=line;curr=[]        for chrac in tmp:            if(int(chrac)>0):                tmp_num=1            else:                tmp_num=0            curr.append(tmp_num)        test_data.append(curr)    return test_datadef test_knn_classify(train_data,train_label):      train_data=array(train_data)      train_label=array(train_label)      n_neighbors=15      weights = 'uniform'      clf = neighbors.KNeighborsClassifier(n_neighbors, weights)      clf.fit(train_data, train_label)      res = clf.predict(train_data[25])      return resdef knn_classify(train_data,train_label,test_data):      train_data=array(train_data)      train_label=array(train_label)      test_data = array(test_data)      n_neighbors=15      weights = 'uniform'      clf = neighbors.KNeighborsClassifier(n_neighbors, weights) //配置scikit-learn的KNN      clf.fit(train_data, train_label)//拟合数据      res = clf.predict(test_data) //进行预测      return resdef gen_res_file(result):     file = open('result.csv','wb')       my_save=csv.writer(file)       tmp=['ImageId','Label']     my_save.writerow(tmp)      cnt=1     for i in result:              tmp=[]              tmp.append(cnt)            tmp.append(i)              cnt=cnt+1            print cnt            my_save.writerow(tmp)      file.close()

运行的时候,直接调用digit.top_test()即可得到result.csv。提交结果,正确率是95.74%。

0 0
原创粉丝点击