转: Kaggle入门模板:以手写识别Digit Recognizer为例

来源:互联网 发布:aws 阿里云 价格 编辑:程序博客网 时间:2024/05/14 17:26

首先本文参考了点击打开链接 这篇博客,然后可能时间有点久远,Kaggle的这道题给的数据文档和之前的不一样了,以及还有一些注意点这篇文章里没有突出。因此这里重新做个总结,希望大家能早点入个门。

这里我使用的sklearn中的支持向量机来解决手写识别问题。这里的svm是可以解决多分类问题的。核函数使用的是高斯核(rbf),松弛变量c选择的是5.

kaggle这道题一共提供了3个文件:train.csv,test.csv,sample_submission.csv 。 分别表示训练集,测试集,提交样例。

下面上python代码。本人的macbook pro16,运行时间为575秒。svm的准确率在这个问题上可能不及knn,但是运行的效率要比knn高了许多。。。

[python] view plain copy
  1. #!/usr/bin/python    
  2. # -*- coding: utf-8 -*-    
  3.     
  4. from numpy import *    
  5. from sklearn import svm      
  6. import csv     
  7. import datetime  
  8.   
  9. #把数组中的字符串转换成整数  
  10. def toInt(array):   
  11.     array=mat(array)    
  12.     m,n=shape(array)    
  13.     #使用xrange不会生成list,性能要优于range  
  14.     for i in xrange(m):    
  15.         for j in xrange(n):    
  16.                 array[i,j]=int(array[i,j])    
  17.     return array    
  18.   
  19. #把大于0的数都置为1  
  20. def nomalizing(array):    
  21.     m,n=shape(array)    
  22.     for i in xrange(m):    
  23.         for j in xrange(n):    
  24.             if array[i,j]!="0":  #注意原csv文件中的数字也是字符串  
  25.                 array[i,j]=1    
  26.             else:  
  27.                 array[i,j]=0  
  28.     return array    
  29.   
  30. def loadTrainData():    
  31.     l=[]    
  32.     with open('train.csv') as file:    
  33.          lines=csv.reader(file)    
  34.          for line in lines:    
  35.              l.append(line) #42001*785    
  36.     l.remove(l[0])  #移除第0行,第0行是数据列名  
  37.     l=array(l)  #将l由list型转化为numpy下的array型  
  38.     label=l[:,0]  #label赋值为l的第0列  
  39.     data=l[:,1:]  #data赋值为l的第1至最后一列  
  40.     return nomalizing(data),toInt(label)   
  41.   
  42. def loadTestData():    
  43.     l=[]    
  44.     with open('test.csv') as file:    
  45.          lines=csv.reader(file)   
  46.          for line in lines:    
  47.              l.append(line)    
  48.     l.remove(l[0])    
  49.     data=array(l)    
  50.     return nomalizing(data)    
  51.   
  52. def saveResult(result,csvName):    
  53.     with open(csvName,'wb') as myFile:        
  54.         myWriter=csv.writer(myFile)   
  55.         num = 1   
  56.         arr=[]  
  57.         arr.append("ImageId")  
  58.         arr.append("Label")  
  59.         myWriter.writerow(arr)  #先将列名插入第0行  
  60.         for i in result:    
  61.             tmp=[]   
  62.             tmp.append(num)  
  63.             num = num + 1   
  64.             tmp.append(int(i))  ##不能是浮点数    
  65.             myWriter.writerow(tmp)    
  66.       
  67. def svcClassify(trainData,trainLabel,testData):     
  68.     svcClf=svm.SVC(C=5.0#default:C=1.0,kernel = 'rbf'. you can try kernel:‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’      
  69.     svcClf.fit(trainData,ravel(trainLabel))    
  70.     testLabel=svcClf.predict(testData)    
  71.     saveResult(testLabel,'sklearn_SVC_C=5.0_Result.csv')    
  72.     return testLabel    
  73.   
  74. def main():    
  75.     starttime = datetime.datetime.now()  
  76.     trainData,trainLabel=loadTrainData()    
  77.     print "训练集读取完毕"  
  78.     testData=loadTestData()     
  79.     print "测试集读取完毕"  
  80.     svcClassify(trainData,trainLabel,testData)  
  81.     endtime = datetime.datetime.now()  
  82.     print "预测结束--程序总运行时间:"+str((endtime - starttime).seconds)+"秒"  
  83.   
  84. main() #主函数  

ps:本人一开始在kaggle上提交结果,总是返回的准确率为0.00000,后来用文本编辑器打开了csv,才发现自己生成的label都是浮点数,而在excel中看不出来,坑。

kaggle提交注意事项:

  1. 每道题目一天最多交5次,大家珍惜每天的提交机会
  2. 提交的csv要严格遵循sample_submission.csv中的格式,也就是在提交文件中第一行的列名也是需要加的,且列名不能出错。
  3. 提交的数据一定要弄清是整数还是浮点数。否则提交后是会被判断为预测错误的。
       kaggle这个平台真心不错,让我找回了codeforces的感觉,感觉找到了一个很好的锻炼动手能力的平台,希望大家能经常做做练习~
原创粉丝点击