XGBoost源码阅读笔记(1)--代码逻辑结构
来源:互联网 发布:淘宝真皮女鞋店铺推荐 编辑:程序博客网 时间:2024/06/14 09:05
一. XGBoost简介
XGBoost(eXtreme Gradient Boosting)是基于GB(Gradient Boosting)模型框架实现的一个高效,便捷,可扩展的一个机器学习库。该库先由陈天奇在2014年完成v0.1版本之后开源到github[1]上,当前最新版本是v0.6。目前在各类相关竞赛中都可以看到其出现的身影,如kaggle[2],在2015年29个竞赛中,top3队伍发表的解决方案中有17个方案使用了XGBoost,而只有11个解决方案使用了深度学习;同时在2015KDDCup中top10队伍都使用了XGBoost[3]。由于其与GBDT(Gradient Boosting decision Tree)存在一定相似之处,网上也经常会有人将GBDT和XGBoost做个对比[4]。最近正好读了陈天奇的论文《XGBoost: A Scalable Tree Boosting System》[3], 从论文中可以看出XGBoost新颖之处在于:
1. 使用了正则化的目标函数,其加入的惩罚项会控制模型复杂度(叶子个数)和叶子结点的得分权重
图1-1 目标函数
2. 使用Shrinkage,通过一个因子η缩减每次最新生成树的权重,其目的是为了降低已生成的树对后续树的影响。
3. 支持列(特征)采样,该方式曾被用于随机森林。可以预防过拟合且加快模型训练速度。
4. 并行计算。Boost方式树是串行生成的,所以其在寻找树分裂点时候进行并行计算,加快模型训练速度。
在寻找分裂点时候论文中也提到多种方式:
1. 基本枚举贪婪搜索算法。该方式将特征按其值排序之后,枚举每个特征值作为其分裂点并计算该分裂点的增益,然后选择最大增益的分裂点
2. 近似贪婪搜索算法。该方式在寻找分裂点前会将所有的特征按其对应值进行排序后选择其百分位的点作为候选集合,在执行基本穷举贪婪搜索法。
3. 加权分位数法(weighted quantile sketch)。该方法可以用于对加权数据的处理。
4. 稀疏分裂点查找。可以加快模型对稀疏数据处理。
其与GBDT不同之一在于其对目标函数进行二阶泰勒展开,使用了二阶导数加快模型收敛速度。总的来说XGBoost受到欢迎最重要的一个因素在于其快速的训练过程。
二. 源码下载及编译
Linux上的源码下载和编译过程如下[5]:
git clone --recursive https://github.com/dmlc/xgboostcd xgboostmake
使用--recursive命令是因为XGBoost使用了作者自己编写的分布式计算库,通过这个命令可以下载对应的库,编译好之后就可以开始阅读源码了,XGBoost主要代码目录结构如下:
|--xgboost |--include |--xgboost //定义了xgboost相关的头文件 |--src |--c_api |--common //一些通用文件,如对配置文件的处理 |--data //使用的数据结构,如DMatrix |--gbm //定义了若分类器,如gbtree和gblinear |--metric //定义评价函数 |--objective //定义目标函数 |--tree //对树的一些列操作
三. 源码逻辑结构
程序的执行入口在cli_main.cc文件中
//cli_main.cc|--main() |--CLIRunTask() |--CLIParam::configure() |--switch(param.task) { case kTrain: CLITrain(param);break; case KDumpModel: CLIDumpModel(param);break; case KPredict: CLIPredict(param);break; }
在main函数中只调用了CLIRunTask()函数,在该函数中可以看出,程序通过函数configure()解析配置文件后,根据参数task选择对应的执行函数。我们这里主要看下训练函数CLITrain();
//cli_main.cc|--CLITrain() |--DMatrix::Load() |--Learner::Create() |--Learner::Configure() |--Learner::InitModel() |--for (int iter = 0; iter < max_iter; ++iter) { Learner::UpdateOneIter(); Learner::EvalOneIter(); }
在CLI函数中, 先是将训练数据加载到内存中,然后开始创建Learner类实例, 接着调用Learner的configure函数配置参数,调用InitModel()初始化模型。然后就开始XGboost的Boosting训练,主要调用的是Learner的UpdateOneIter()函数。
//learner.cc|--UpdateOneIter() |--learner::LazyInitDMatrix() |--learner::PredictRaw() |--ObjFunction::GetGradient() |--GradientBooster::DoBoost()
在每次迭代过程中,LazyInitDMatrix()先初始化需要用到的数据结构。GetGradient()获取目标函数的一阶导和二阶导,最后DoBoost()执行Boost操作生成一棵回归树。Class GradientBoost是一个抽象类,他定义了Gradient Boost的抽象接口。其派生出的两个类Class GBTree和 Class GBLinear 分别对应着配置文件里面的参数“gbtree”和“gblinear”, Class GBTree主要使用的回归树作为其弱分类器,而Class GBLinear使用的是线性回归或逻辑回归作为其弱分类器。
Class GBTree用的比较多,其DoBoost()函数执行的操作如下:
//gbtree.cc|--GBTree::DoBoost() |--GBTree::BoostNewTrees() |--GBTree::InitUpdater() |--TreeUpdater::Update()
DoBoost()调用了BoostNewTrees()函数。在BoostNewTrees()中先初始化了TreeUpdater实例,在调用其Update函数生成一棵回归树。TreeUpdater是一个抽象类,根据使用算法不同其派生出许多不同的Updater,这些Updater都在src/tree目录下。
|--src |--tree |--updater_basemaker-inl.h |--updater_colmaker.cc |--updater_skmaker.cc |--updater_refresh.cc |--updater_prune.cc |--updater_hismaker.cc |--updater_fast_hist.cc
文件updater_basemaker-inl.h中定义了一个派生自TreeUpdater的类BaseMaker。Class ColMaker使用的是基本枚举贪婪搜索算法,通过枚举所有的特征来寻找最佳分裂点;Class SkMaker派生自BaseMaker,使用近似的sketch方法寻找最佳分裂点;Class TreeRefresher用于刷新数据集上树的统计信息和叶子值;Class TreePruner是树的剪枝操作;Class HistMaker使用的是直方图法,该方法在论文中并没有提到,所以也不是很清楚。
至此便可以大致了解XGBoost源码的逻辑结构,目前源码只看到这里。等看了各算法的具体实现之后再在后续文章中写其具体实现细节。
四. 参考文献
[1]. https://github.com/dmlc/xgboost
[2]. https://www.kaggle.com
[3]. Tianqi Chen and Carlos Guestrin. XGBoost: A Scalable Tree Boosting System. In 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016 网址:https://arxiv.org/abs/1603.02754
[4]. https://www.zhihu.com/question/41354392
[5]. http://xgboost.readthedocs.io/en/latest/build.html
- XGBoost源码阅读笔记(1)--代码逻辑结构
- MySQL源码阅读笔记之代码结构
- XGBoost源码阅读笔记(2)--树构造之Exact Greedy Algorithm
- [hadoop源码阅读][1]-源码目录结构
- [hadoop源码阅读][1]-源码目录结构
- XGBoost代码走读分析笔记
- 代码分析:NASM源码阅读笔记
- x264源码阅读笔记1
- ActiveAndroid 源码阅读笔记 (1)
- ViewFlow 源码阅读笔记(1)
- flashsim源码阅读笔记1
- tinySLAM代码逻辑结构
- TensorFlow0.8源码阅读 -- 代码目录结构讲解
- TensorFlow0.8源码阅读 -- 代码目录结构讲解
- 非典型2D游戏引擎 Orx 源码阅读笔记(1) 总体结构
- jQuery源码阅读笔记——整体结构
- x264代码阅读笔记1
- vnc 代码阅读笔记1
- java 集合 Map的遍历方式
- LeetCode-20-Valid Parentheses(有效的括号)
- NAT技术 与 代理服务器
- 正则进阶之旅-五条
- 关于我的CSDN博客的一些要说的话
- XGBoost源码阅读笔记(1)--代码逻辑结构
- C++之必须返回对象时候,别妄想返回其reference(21)---《Effective C++》
- Spring Batch 注册监听器
- Error:Execution failed for task ':app:transformClassesWithDexForRelease'. > com.android.build.api.tr
- 交叉编译GCC for arm
- HPU2017-2016级暑期集训练习赛
- 极光推送图标遇到问题及退出极光推送帐号
- python_selenium(五)
- 对 IIC 总线的理解、调用函数以及常见面试问题