周志华新论文gcForest手写数字测试识别详解(Kaggle数据集)
来源:互联网 发布:培训学校收费软件 编辑:程序博客网 时间:2024/06/04 19:31
lamda实验室最近研发了新的学习算法gcForest,论文和代码在lamda网站上都有,给出的代码没有注释相当费解(大牛请无视),后又在github上找到民间大神实现的代码,较为简洁易懂,先贴上代码链接:https://github.com/pylablanche/gcForest
简要对gcForest算法处理数据的流程解释一下:
假设我们现在的数据就是kaggle平台上手写数字识别的数据格式,具体格式见https://www.kaggle.com/c/digit-recognizer/data
那么但数据规模是28*28,假设我们设置扫描窗口window为14*14:
先有多粒度扫描过程,这个过程实质上是将原数据特征的一种“放大”与分离处理:
由28-14+1=15,
可知每行数据会切出:15*15=225个窗口,每个窗口14*14的规模
那么原来数据集的每行都变成了225行,即225个小窗口,窗口为14*14的特征块。即每行有14*14列(sliced_X)
这时sliced_Y还是int,被重复了225次,即每行还是对应一个y,指这个数据的正确标签数字。
然后将sliced_X,sliced_Y送给随机森林和完全随机森林训练,然后再用这俩森林对sliced_X跑出结果概率(十维,表示每个手写数字的概率),然后把俩十维文件合并为20维的,然后又把概率矩阵规模重设成了原始数据的行数。
注意这时数据格式是森林预测到的概率! 原始特征已经“看不到”了,即以后训练级联森林用的不是原数据特征,而是原数据经常森林预测出的概率,以概率为输入特征训练接下来的级联森林。这一点非常容易混淆。
重复上面的,把所有window值都跑一遍,整合概率矩阵,最后MutiScanning返回的行数等于初始数据行数,每行都是预测的概率,列数极多。然后将它送入级联森林。
训练级联森林就简单多了,每一层接受上一层的数据,并检验性能,没有提升就停下来。
论文介绍该算法对序列数据和图像数据效果较好,并相对于深度神经网络有一些优势,详见论文。
我下载来理解之后简单进行了测试,采用了kaggle平台上的手写数字的数据集,在参数都使用默认情形下,依然得到了不错的识别率。
这里给出测试方法:(需要先在上面链接下载gcForest代码,与本测试代码放在同目录下)
# -*-coding:utf-8-*-import pandas as pdimport numpy as npfrom sklearn.cross_validation import train_test_splitimport GCForestimport pandas as pdimport timeimport numpy as npfrom sklearn.model_selection import train_test_splitfrom GCForest import gcForestfrom sklearn.metrics import accuracy_scoredata= pd.read_csv('train.csv')#print(data.shape)ddata=np.array(data)x=ddata[:1400,1:785].copy() #为快速看到测试结果,这里只试用1400条数据,可自行更改y=ddata[:1400,:1].copy()y=y.flatten()#print(x)#print(y.shape)X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.25,random_state=9)#print(y_train)gcf = GCForest.gcForest(shape_1X=[28,28], window=[14,16], tolerance=0.0, min_samples_mgs=10, min_samples_cascade=7)gcf.fit(X_train, y_train)pred_X = gcf.predict(X_test)#print (pred_X)accuracy = accuracy_score(y_true=y_test, y_pred=pred_X)print ('gcForest accuracy:{}'.format(accuracy))
由于gcForest代码运行时十分耗内存,博主16g的内存不够用,当把kaggle上四万数据全部用上时,扫描窗口window不能设为合适值,window大约14左右合适,但目前内存限制只能设到23以上,并不能完全发挥gcForest的能力,kaggle上提交准确率有98.2%
每次得到的结果可能不同,这是因为其中随机森林有一定随机性,但总体差别不大。
后期调参工作应该可以再提升识别率。
- 周志华新论文gcForest手写数字测试识别详解(Kaggle数据集)
- 在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别
- kaggle-Digit Recognition(手写数字识别)
- 在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别
- kaggle入门篇一【手写数字识别】
- 报告论文:手写数字识别
- 经典手写数字mnist数据集识别
- Kaggle 手写识别题
- kaggle-手写字体识别
- 【Kaggle笔记】手写数字识别分类(线性支持向量机)
- Kaggle Digit Recognizer使用keras实现手写数字识别 A1
- Tensorflow深度学习笔记(五)--手写数字识别-MNIST数据测试
- 用TensorFlow做Kaggle“手写识别”达到98%准确率-详解
- kaggle Code :手写识别 TensorFlow
- kaggle+mnist手写字体识别
- python tensorflow 使用minist数据集实现手写数字识别
- TensorFlow学习笔记(3)--实现Softmax逻辑回归识别手写数字(MNIST数据集)
- kaggle的手写识别比赛(python sklearn-KNN)
- iPhone、iPad默认按钮样式问题
- Java 正则表达式
- 焦点轮播图代码详解!基础版本
- 463. Island Perimeter Difficulty : Easy
- 猫和老鼠
- 周志华新论文gcForest手写数字测试识别详解(Kaggle数据集)
- Java 方法
- 安卓开发-intent属性总结
- Linux把普通用户加入sudo组
- OpenStack之安装nova
- VMware + Linux + Xshell 连接环境设置(心得体会)
- Spring自动装配的方法
- linux 下awk 的使用
- Redis 使用入门