周志华新论文gcForest手写数字测试识别详解(Kaggle数据集)

来源:互联网 发布:培训学校收费软件 编辑:程序博客网 时间:2024/06/04 19:31

lamda实验室最近研发了新的学习算法gcForest,论文和代码在lamda网站上都有,给出的代码没有注释相当费解(大牛请无视),后又在github上找到民间大神实现的代码,较为简洁易懂,先贴上代码链接:https://github.com/pylablanche/gcForest

简要对gcForest算法处理数据的流程解释一下:
假设我们现在的数据就是kaggle平台上手写数字识别的数据格式,具体格式见https://www.kaggle.com/c/digit-recognizer/data

那么但数据规模是28*28,假设我们设置扫描窗口window为14*14:

先有多粒度扫描过程,这个过程实质上是将原数据特征的一种“放大”与分离处理:

由28-14+1=15,
可知每行数据会切出:15*15=225个窗口,每个窗口14*14的规模
那么原来数据集的每行都变成了225行,即225个小窗口,窗口为14*14的特征块。即每行有14*14列(sliced_X)
这时sliced_Y还是int,被重复了225次,即每行还是对应一个y,指这个数据的正确标签数字。
然后将sliced_X,sliced_Y送给随机森林和完全随机森林训练,然后再用这俩森林对sliced_X跑出结果概率(十维,表示每个手写数字的概率),然后把俩十维文件合并为20维的,然后又把概率矩阵规模重设成了原始数据的行数。
注意这时数据格式是森林预测到的概率! 原始特征已经“看不到”了,即以后训练级联森林用的不是原数据特征,而是原数据经常森林预测出的概率,以概率为输入特征训练接下来的级联森林。这一点非常容易混淆。

重复上面的,把所有window值都跑一遍,整合概率矩阵,最后MutiScanning返回的行数等于初始数据行数,每行都是预测的概率,列数极多。然后将它送入级联森林。

训练级联森林就简单多了,每一层接受上一层的数据,并检验性能,没有提升就停下来。

论文介绍该算法对序列数据和图像数据效果较好,并相对于深度神经网络有一些优势,详见论文。

我下载来理解之后简单进行了测试,采用了kaggle平台上的手写数字的数据集,在参数都使用默认情形下,依然得到了不错的识别率。

这里给出测试方法:(需要先在上面链接下载gcForest代码,与本测试代码放在同目录下)

# -*-coding:utf-8-*-import pandas as pdimport numpy as npfrom sklearn.cross_validation import train_test_splitimport GCForestimport pandas as pdimport timeimport numpy as npfrom sklearn.model_selection import train_test_splitfrom GCForest import gcForestfrom sklearn.metrics import accuracy_scoredata= pd.read_csv('train.csv')#print(data.shape)ddata=np.array(data)x=ddata[:1400,1:785].copy()  #为快速看到测试结果,这里只试用1400条数据,可自行更改y=ddata[:1400,:1].copy()y=y.flatten()#print(x)#print(y.shape)X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.25,random_state=9)#print(y_train)gcf = GCForest.gcForest(shape_1X=[28,28], window=[14,16], tolerance=0.0, min_samples_mgs=10, min_samples_cascade=7)gcf.fit(X_train, y_train)pred_X = gcf.predict(X_test)#print (pred_X)accuracy = accuracy_score(y_true=y_test, y_pred=pred_X)print ('gcForest accuracy:{}'.format(accuracy))

由于gcForest代码运行时十分耗内存,博主16g的内存不够用,当把kaggle上四万数据全部用上时,扫描窗口window不能设为合适值,window大约14左右合适,但目前内存限制只能设到23以上,并不能完全发挥gcForest的能力,kaggle上提交准确率有98.2%

每次得到的结果可能不同,这是因为其中随机森林有一定随机性,但总体差别不大。

后期调参工作应该可以再提升识别率。

阅读全文
0 0
原创粉丝点击