XGBoost源码阅读笔记(2)--树构造之Exact Greedy Algorithm
来源:互联网 发布:阿里大数据查询 编辑:程序博客网 时间:2024/06/08 08:09
在上一篇《XGBoost源码阅读笔记(1)--代码逻辑结构》中向大家介绍了XGBoost源码的逻辑结构,同时也简单介绍了XGBoost的基本情况。本篇将继续向大家介绍XGBoost源码是如何构造一颗回归树,不过在分析源码之前,还是有必要先和大家一起推导下XGBoost的目标函数。本次推导过程公式截图主要摘抄于陈天奇的论文《XGBoost:A Scalable Tree Boosting System》。在后续的源码分析中,会省略一些与本篇无关的代码,如并行化,多线程。
一、目标函数优化
XGBoost和以往的GBT(Gradient Boost Tree)不同之一在于其将目标函数进行了二阶泰勒展开,在模型训练过程中使用了二阶导数加快其模型收敛速度。与此同时,为了防止模型过拟合其给目标函数加上了控制模型结构的惩罚项。
图1-1 目标函数
目标函数主要有两部分组成。第一部分是表示模型的预测误差;第二部分是表示模型结构。
当模型预测误差越大,树的叶子个数越多,树的权重越大,目标函数就越大。我们的优化目标是使目标函数尽可能的小,这样在降低预测误差的同时也会减少树叶子的个数以及降低叶子权重。这也正符合机器学习中的“奥卡姆剃刀”原则,即选择与经验观察一致的最简单假设。
图1-1的目标函数由于存在以函数为参数的模型惩罚项导致其不能使用传统的方式进行优化,所以将其改写成如下形式
图1-2 改变形式的目标函数
图1-2与图1-1的区别在于图1-1是通过整个模型去优化函数,而图1-2的优化目标是每次迭代过程中构造一个使目标函数达到最小值的弱分类器,从这个过程中就可以看出图1-2使用的是贪婪算法。将图1-2中的预测误差项在处进行二阶泰勒展开:
图1-3 二阶泰勒展开
并省去常数项
图1-4 省去常数项
图1-4就是每次迭代过程中简化的目标函数。我们的目标是在第t次迭代过程中获得一个使目标函数达到最小值的最优弱分类器,即。在这里累加项n是样本实例的个数,为了使编码更加方便,定义一个新的变量表示表示叶子j的所有样本实例
图1-5 新的变量
同时展开目标函数的模型惩罚项,并以叶子为纬度可以改写成
图1-6 以叶子为纬度的目标函数
这里函数f是将对应实例归类到对应的叶子下,并返回该实例在当前叶子下的权重w。图1-6对叶子权重w求导,便得出最优的叶子权重w
图1-7最优的叶子权重
与此同时将权重代入目标函数,并且省去常量,便得到了目标函数的解析式
图1-8 目标函数的解析式
我们的目标便是极小化该目标函数解析式。目标函数的解析式可以通过图1-9清晰形象的描绘出来
图1-9 目标函数的解析式计算过程
从图1-9可以清晰看出目标函数解析式的计算过程。目标函数的结果可以用来评价模型的好坏。这样在模型训练过程中,当前的叶子结点是否需要继续分裂主要就看分裂后的增益损失loss_change。
图1-10 分裂增益
增益损失loss_change的计算公式如图1-10所示,它是由该结点分裂后的左孩子增益加上右孩子增益减去该父结点的增益。这样在选择分裂点时候就是选择增益损失最大的分裂点。而寻找最佳分裂点是一个非常耗时的过程,上一篇《XGBoost源码阅读笔记(1)--代码逻辑结构》介绍了几种XGBoost使用的分裂算法,这里选择其中最简单的Exact Greedy Algorithm进行讲解:
图1-11 Exact Greedy Algorithm
图1-11算法的大意是遍历每个特征,在每个特征中选择该特征下的每个值作为其分裂点,计算增益损失。当遍历完所有特征之后,增益损失最大的特征值将作为其分裂点。由此可以看出这其实就是一种穷举算法,而整个树构造过程最耗时的过程就是寻找最优分裂点的过程。但是由于该算法简单易于理解,所以就以该算法来向大家介绍XGBoost源码树构造的实现过程。
如果对推导过程读起来比较吃力的话也没关系,这里主要需要记住的是每个结点增益和权值的计算公式。增益是用来决定当前结点是否需要继续分裂下去,而结点权值的线性组合即是模型最终的输出值。所以只要记住这两个公式就不会影响源码的阅读。
二、源码分析
1) 代码逻辑结构回顾
在上一篇结尾的时候说过源码最终调用过程如下:
//gbtree.cc|--GBTree::DoBoost() |--GBTree::BoostNewTrees() |--GBTree::InitUpdater() |--TreeUpdater::Update()
这里简化后的源码如下:
//gbtree.cc line:452BoostNewTrees(const std::vector<bst_gpair> &gpair, DMatrix *p_fmat, int bst_group, std::vector<std::unique_ptr<RegTree> >* ret) { this->InitUpdater(); std::vector<RegTree*> new_tress; for(auto& up: updaters){ up->Update(gpair,p_fmat, new_trees); }}
gpair是一个vector向量,保存了对应样本实例的一阶导数和二阶导数。p_fmat是一个指针,指向对应样本实例的特征,new_trees用于存储构造好的回归树。
InitUpdater()是为了初始化updaters, 在上一篇文章也说过updaters是抽象类Class TreeUpdater的指针对象,定义了基本的Init和Update接口,该抽象的派生类定义了一系列树构造和剪枝方法。这里主要介绍其派生类Class ColMaker,该类使用的即使我们前面介绍的Exact Greedy Algorithm
2) Class ColMaker 数据结构介绍
在ClassColMaker 定义了一些数据结构用于辅助树的构造。
//updater_colmaker.cc line:755const TrainParam& param; //训练参数,即我们设置的一些超参数std::vector<int> position; //当前样本实例在回归树结中对应结点的索引std::vector<NodeEntry> snode; //回归树中的结点std::vector<int> qexpand_; //保存将有可能分类的结点的索引
XGBoost的树构造类似于BFS(Breadth First Search),它是一层一层的构造树结点。所以需要一个队列qexpand_用来保存当前层的结点,这些结点会根据增益损失loss_change决定是否需要分裂形成下一层的结点。
3) Class ColMaker 树构造源码
//updater_colmaker.cc line:29void ColMaker::Update(...){ for(size_t i = 0; i < trees.size(); ++){ Builder builder(param); builder.Update(gpair, dmat, trees[i]); }}
在Class ColMaker中定义了一个Class Builder类,所有的构造过程都由这个类完成。
//updater_colmaker.cc line:89void ColMaker::Builder::Update(...){ this -> InitData(...); //初始化Builder参数 // 初始化树根结点的权值和增益 this -> InitNewNode(gpair, *p_fmat,*p_tree); for( int depth = 0; depth < param.max_depth; ++depth) { //给队列中的当层结点寻找分裂特征,构造出树的下一层 this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree); //将当层各个非叶子结点中的样本实例分类到下一层的各个结点中 this->ResetPosition(); //更新队列,存储下一个层结点 this->UpdateQueueExpand(); //计算队列中下一层结点的权值和增益 this->InitNewNode(); //如果当前队列中没有候选分裂点,就退出循环 If(qexpand_.size() == 0) break; } //由于树的深度限制,将队列中剩下结点都设置为树的叶子 for(size_t i = 0; i < qexpand_.szie(); ++i) { ... } //记录构造好的回归树的一些辅助统计信息 ...}
在以上代码中核心部分就是第一个循环里面的四个函数。我们首先来看下Builder::InitNewNode是如何初始化结点的增益和权值。
1. Builder::InitNewNode()
//updater_colmaker.cc|--Builder::InitNewNode() |--for(size_t j = 0; j < qexpand_.size(); ++j) |--{ |-- snode[qexpand[j]].root_gain = CalGain(...) |-- snode[qexpand[j]].weight = CalWeight(...) |--}
这里点的root_gain就是前面说的结点增益,将用于判断该点是否需要分裂。weigtht就是当前点的权值,最终模型输出就是叶子结点weight的线性组合。CalGain()和CalWeight()是两个模版函数,其简化的源码如下:
//param.h line:242Template<typename TrainingParams, typename T>T CalGain(const TrainingParams &p, T sum_grad, T sum_hess){ return (sum_grad * sum_grad)/( sum_hess + p.reg_lambda);}//param.h line:275Template<typename TrainingParams, typename T>T CalWeight(const TrainingParams &p, T sum_grad, T sum_hess){ return -sum_grad /( sum_hess + p.reg_lambda);}
以上两个函数就是实现了我们一开始推导的两个公式,即计算结点的增益和权重。在初始化队列中的结点后,就需要对队列中的每个结点遍历寻找最优的分裂属性。
2. XGBoost::Builder::FindSplit()
//updater_colmaker.cc|--XGBoost::Builder::FindSplit() |--//寻找特征的最佳分裂值 |--for(size_t i = 0; i< feature_num; i++) |--{ |-- XGBoost::Builder::UpdateSolution() | --XGBoost::Builder::EnumerateSplit() |--}分裂过程最终调用了EnumerateSplit()函数,为了便于理解对代码变量名做了修改,其简化的代码如下
//updater_colmaker.cc line:508void EnumerateSplit(...){ //建立个临时变量temp用来保存结点信息 //空间大小为队列qexpand_中结点的最大索引 vector<TStats> temp( std::max(qexpand_) + 1); TStats left_child(param) //结点分裂后左孩子的统计信息 //遍历当前特征的所有值 for(const ColBatch::Entry * it = begin; it != end; it += d_step){ //得到当前特征值所对应的样本实例索引和特征值 const int rIndex = it -> index; const int fValue = it->value; //根据当前样本索引得到其分类到的结点索引 const int node_id = position[rIndex ]; //结点分裂后右孩子的统计信息 TStats & right_child = temp[node_id] //以当前特征值为分裂阈值,将当前样本归类到左孩子 left_child = snode[node_id].stats - right_child; //计算增益损失 int loss_change= CalcSplitGain(param, left_child, right_child) - snode[node_id].root_gain; //记录下最好的特征值分裂阈值,该阈值是左右孩子相邻特征值的中间值 right_child.best.Update(loss_change, feature_id , 0.5 * (fValue + right_child.left_value) ); //将当前样本实例归类到右孩子结点 right_child.add(gpair, info , rIndex) }}从上述代码可以很清晰看出整个代码的流程思路就是之前介绍的Exact Greedy Alogrithm. 这里需要说明寻找分裂点有两个方向,一个是从左到右寻找,一个是从右到左寻找。上述代码只展示了一个方向的寻找过程。在寻找特征分裂阈值的时候分裂增益的计算函数是CalcSplitGain(),其具体代码如下:
//param.h line:365double CalcSplitGain(const TrainParam ¶m, GradStats left, GradStats right) const { return left.CalcGain(param) + right.CalcGain(param);}上述代码就是简单将左孩子和右孩子的增益相加,而增益损失loss_change就是将左右孩子相加的增益减去其父节点的增益。
3. XGboost::Builder::ResetPosition()
在寻找到当前层各个结点的分裂阈值之后,便可以在对应结点上构造其左右孩子来增加当前树的深度。当树的深度增加了,就需要将分类到当前层非叶子结点的样本实例分类到下一层的结点中。这个过程就是通过ResetPosition()函数完成的。
4. XGboost::Builder::UpdateQueueExpand()
XGboost::Builder::UpdateQueueExpand()函数更新qexpand_队列中的结点为下一层结点,然后在调用XGboost::Builder::InitNewNode()更新qexpand_中结点的权值和增益以便下一次循环
三、总结
本篇主要详细叙述了XGBoost使用Exact Greedy Algorithm构造树的方法,并分析了对应的源码。在分析源码过程中为了便于理解对代码做了一些简化,如省去了其中多线程,并行化的操作,并修改了一些变量名。在上述的树构造完成之后,还需要对树进行剪枝操作以防止模型过拟合。由于篇幅所限,这里就不再介绍剪枝操作。本篇文章只是起一个抛砖引玉的引导作用,想要对XGBoost实现细节有更加深刻理解,还需要去阅读XGBoost源码,毕竟有些东西用文字描述远不如用代码描述清晰。最后欢迎大家一起来讨论。
- XGBoost源码阅读笔记(2)--树构造之Exact Greedy Algorithm
- XGBoost源码阅读笔记(1)--代码逻辑结构
- 贪心算法(Greedy Algorithm)之最小生成树 克鲁斯卡尔算法(Kruskal's algorithm)
- greedy algorithm
- Greedy algorithm
- USACO 之 Section 1.3 Greedy Algorithm
- 算法之贪心算法(greedy algorithm)
- exact algorithm 精确算法
- 最大生成树(Greedy Algorithm)
- 【贪心法求解最小生成树之Kruskal算法详细分析】---Greedy Algorithm for MST
- 贪心算法(Greedy Algorithm)之霍夫曼编码(Huffman codes)
- 贪心算法(Greedy Algorithm)之霍夫曼编码(Huffman codes)
- hdu 1052 (greedy algorithm)
- greedy algorithm DEMO
- Section 1.3 Greedy Algorithm
- Greedy Algorithm--Algorithms
- x264源码阅读笔记2
- ActiveAndroid 源码阅读笔记 (2)
- Unity 数据库的简单使用
- 树形dp(洛谷1040 加分二叉树noip2003提高组第三题)
- 字符串处理排序(洛谷1012 拼数)
- dp+高精度 (洛谷1005 矩阵取数游戏 NOIP 2007 提高第三题)
- 转 浅谈用极大化思想解决最大子矩形问题
- XGBoost源码阅读笔记(2)--树构造之Exact Greedy Algorithm
- 求最大子矩阵悬线法(codevs 1159 最大全0子矩阵)
- JS学习第四天
- java操作Redis数据库的redis工具,RedisUtil,jedis工具JedisUtil,JedisPoolUtil
- 悬线法求最大子矩阵(洛谷P1169 [ZJOI2007]棋盘制作 bzoj1057)
- dfs(洛谷1019 单词接龙NOIp2000提高组第三题)
- lca(洛谷P3379 最近公共祖先(LCA))
- poj 2251 Dungeon Master(多起点bfs)
- LCA 最近公共祖先