LambdaMART简介:lambda计算及Regression Tree训练

来源:互联网 发布:北京软件行业协会电话 编辑:程序博客网 时间:2024/05/29 15:12

part1: lambda计算(来源:http://www.mamicode.com/info-detail-149823.html)

学习Machine Learning,阅读文献,看各种数学公式的推导,其实是一件很枯燥的事情。有的时候即使理解了数学推导过程,也仍然会一知半解,离自己写程序实现,似乎还有一道鸿沟。所幸的是,现在很多主流的Machine Learning方法,网上都有open source的实现,进一步的阅读这些源码,多做一些实验,有助于深入的理解方法。

Ranklib就是一套优秀的Learning to Rank领域的开源实现,其主页在:http://people.cs.umass.edu/~vdang/ranklib.html,从主页中可以看到实现了哪些方法。其中由微软发布的LambdaMART是IR业内常用的Learning to Rank模型,本文介绍RanklibV2.1(当前最新的时RanklibV2.3,应该大同小异)中的LambdaMART实现,用以帮助理解paper中阐述的方法。

LambdaMART.java中的LambdaMART.learn()是学习流程的管控函数,学习过程主要有下面四步构成:

1. 计算deltaNDCG以及lambda;

2. 以lambda作为label训练一棵regression tree;

3. 在tree的每个叶子节点通过预测的regression lambda值还原出gamma,即最终输出得分;

4. 用3的模型预测所有训练集合上的得分(+learningRate*gamma),然后用这个得分对每个query的结果排序,计算新的每个query的base ndcg,以此为基础回到第1步,组成森林。

重复这个步骤,直到满足下列两个收敛条件之一:

1. 树的个数达到训练参数设置;

2. Random Forest在validation集合上没有变好。

下面用一组实际的数据来说明整个计算过程,假设我们有10个query的训练数据,每个query下有10个doc,每个q-d对有10个feature,如下:

 1 0 qid:1830 1:0.002736 2:0.000000 3:0.000000 4:0.000000 5:0.002736 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 2 0 qid:1830 1:0.025992 2:0.125000 3:0.000000 4:0.000000 5:0.027360 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 3 0 qid:1830 1:0.001368 2:0.000000 3:0.000000 4:0.000000 5:0.001368 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 4 1 qid:1830 1:0.188782 2:0.375000 3:0.333333 4:1.000000 5:0.195622 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 5 1 qid:1830 1:0.077975 2:0.500000 3:0.666667 4:0.000000 5:0.086183 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 6 0 qid:1830 1:0.075239 2:0.125000 3:0.333333 4:0.000000 5:0.077975 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 7 1 qid:1830 1:0.079343 2:0.250000 3:0.666667 4:0.000000 5:0.084815 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 8 1 qid:1830 1:0.147743 2:0.000000 3:0.000000 4:0.000000 5:0.147743 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 9 0 qid:1830 1:0.058824 2:0.000000 3:0.000000 4:0.000000 5:0.058824 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000010 0 qid:1830 1:0.071135 2:0.125000 3:0.333333 4:0.000000 5:0.073871 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000011 0 qid:1837 1:0.004065 2:0.000000 3:0.500000 4:0.000000 5:0.000000 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000012 0 qid:1837 1:0.459350 2:0.000000 3:0.000000 4:1.000000 5:0.455285 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000013 0 qid:1837 1:0.060976 2:0.333333 3:0.500000 4:0.000000 5:0.065041 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000014 0 qid:1837 1:0.093496 2:0.000000 3:0.000000 4:0.000000 5:0.085366 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000015 0 qid:1837 1:0.195122 2:0.000000 3:0.000000 4:0.000000 5:0.186992 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000016 0 qid:1837 1:0.036585 2:0.333333 3:0.500000 4:0.000000 5:0.040650 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000017 0 qid:1837 1:0.032520 2:0.000000 3:0.000000 4:0.000000 5:0.024390 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000018 0 qid:1837 1:0.073171 2:0.000000 3:0.000000 4:0.000000 5:0.065041 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000019 0 qid:1837 1:0.024390 2:1.000000 3:0.500000 4:1.000000 5:0.048780 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000020 0 qid:1837 1:0.024390 2:0.333333 3:0.500000 4:1.000000 5:0.032520 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000021 0 qid:1840 1:0.000000 2:0.000000 3:0.000000 4:0.000000 5:0.000000 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000022 1 qid:1840 1:0.007364 2:0.200000 3:1.000000 4:0.500000 5:0.013158 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000023 1 qid:1840 1:0.097202 2:0.000000 3:0.000000 4:0.000000 5:0.096491 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000024 2 qid:1840 1:0.169367 2:0.000000 3:0.500000 4:0.000000 5:0.169591 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.00000025 ......

 

为了简便,省略了余下的数据。上面的数据格式是按照Ranklib readme中要求的格式组织(类似于svmlight),除了行号之外,第一列是q-d对的实际label(人标注数据),第二列是qid,后面10列都是feature。

这份数据每组qid中的doc初始顺序可以是随机的,也可以是从实际的系统中获得的当前顺序。总之这个是计算ndcg的初始状态。对于qid=1830,它的10个doc的初始顺序的label序列是:0, 0, 0, 1, 1, 0, 1, 1, 0, 0(虽然这份序列中只有label值为0和1的,实际中也会有2,3等,由自己的标注标准决定)。我们知道dcg的计算公式是:

bubuko.com,布布扣

i表示当前doc在这个qid下的位置(从1开始,避免分母为0),label(i)是doc(i)的标注值。而一个query的dcg则是其下所有doc的加和:

bubuko.com,布布扣

 根据上式可以计算初始状态下每个qid的dcg: 

bubuko.com,布布扣

要计算ndcg,还需要计算理想集的dcg,将初始状态按照label排序,qid=1830得到的序列是1,1,1,1,0,0,0,0,0,0,计算dcg:

bubuko.com,布布扣

两者相除得到初始状态下qid=1830的ndcg:

bubuko.com,布布扣

 

下面要计算每一个doc的deltaNDCG,公式如下:

bubuko.com,布布扣

deltaNDCG(i,j)是将位置i和位置j的位置互换后产生的ndcg变化(其他位置均不变),显然有相同label的deltaNDCG(i,j)=0。

在qid=1830的初始序列0, 0, 0, 1, 1, 0, 1, 1, 0, 0,由于前3的label都一样,所以deltaNDCG(1,2)=deltaNDCG(1,3)=0,不为0的是deltaNDCG(1,4), deltaNDCG(1,5), deltaNDCG(1,7), deltaNDCG(1,8)。

将1,4位置互换,序列变为1, 0, 0, 0, 1, 0, 1, 1, 0, 0,计算得到dcg=2.036,整个deltaNDCG(1,4)的计算过程如下:

bubuko.com,布布扣

同样过程可以计算出deltaNDCG(1,5)=0.239, deltaNDCG(1,7)=0.260, deltaNDCG(1,8)=0.267等。

进一步,要计算lambda(i),根据paper,还需要ρ值,ρ可以理解为doci比docj差的概率,其计算公式为:

bubuko.com,布布扣

Ranklib中直接取σ=1(σ的值决定rho的S曲线陡峭程度),如下图,蓝,红,绿三种颜色分别对应σ=1,2,4时ρ函数的曲线情形(横坐标是si-sj):

bubuko.com,布布扣

初始时,模型为空,所有模型预测得分都是0,所以si=sj=0,ρij≡1/2,lambda(i,j)的计算公式为:

 bubuko.com,布布扣

上式为Ranklib中实际使用的公式,而在paper中,还需要再乘以-σ,在σ=1时,就是符号正好相反,这两种方式应该是等价的,符号并不影响模型训练结果。而:

bubuko.com,布布扣

计算lambda(1),由于label(1)=0,qid=1830中的其他doc的label都大于或者等于0,所以lamda(1)的计算中所有的lambda(1,j)都为负项。将之前计算的各deltaNDCG(1,j)代入,且初始状态下ρij≡1/2,所以:

bubuko.com,布布扣

可以计算出初始状态下qid=1830各个doc的lambda值,如下:

 1 qId=1830    0.000   0.000   0.000   -0.111  -0.120  0.000   -0.130  -0.134  0.000   0.000   lambda(1): -0.495 2 qId=1830    0.000   0.000   0.000   -0.039  -0.048  0.000   -0.058  -0.062  0.000   0.000   lambda(2): -0.206 3 qId=1830    0.000   0.000   0.000   -0.014  -0.022  0.000   -0.033  -0.036  0.000   0.000   lambda(3): -0.104 4 qId=1830    0.111   0.039   0.014   0.000   0.000   0.015   0.000   0.000   0.025   0.028   lambda(4): 0.231  5 qId=1830    0.120   0.048   0.022   0.000   0.000   0.006   0.000   0.000   0.017   0.019   lambda(5): 0.231  6 qId=1830    0.000   0.000   0.000   -0.015  -0.006  0.000   -0.004  -0.008  0.000   0.000   lambda(6): -0.033 7 qId=1830    0.130   0.058   0.033   0.000   0.000   0.004   0.000   0.000   0.006   0.009   lambda(7): 0.240  8 qId=1830    0.134   0.062   0.036   0.000   0.000   0.008   0.000   0.000   0.003   0.005   lambda(8): 0.247  9 qId=1830    0.000   0.000   0.000   -0.025  -0.017  0.000   -0.006  -0.003  0.000   0.000   lambda(9): -0.05110 qId=1830    0.000   0.000   0.000   -0.028  -0.019  0.000   -0.009  -0.005  0.000   0.000   lambda(10): -0.061

 上表中每一列都是考虑了符号的lamda(i,j),即如果label(i)<label(j),则为负值,反之为正值,每行结尾的lamda(i)是前面的加和,即为最终的lambda(i)。

可以看到,lambda(i)在系统中表达了doc(i)上升或者下降的强度,label越高,位置越后,lambda(i)为正值,越大,表示趋向上升的方向,力度也越大;label越小,位置越靠前,lambda(i)为负值,越小,表示趋向下降的方向,力度也大(lambda(i)的绝对值表达了力度。)

然后Regression Tree开始以每个doc的lamda值为目标,训练模型。


part 2:Regression Tree训练(来源:http://www.cnblogs.com/wowarsenal/p/3906081.html)

上一节中介绍了 λλ 的计算,lambdaMART就以计算的每个doc的 λλ 值作为label,训练Regression Tree,并在最后对叶子节点上的样本 lambdalambda 均值还原成 γγ ,乘以learningRate加到此前的Regression Trees上,更新score,重新对query下的doc按score排序,再次计算deltaNDCG以及 λλ ,如此迭代下去直至树的数目达到参数设定或者在validation集上不再持续变好(一般实践来说不在模型训练时设置validation集合,因为validation集合一般比训练集合小很多,很容易收敛,达不到效果,不如训练时一步到位,然后另起test集合做结果评估)。

 

其实Regression Tree的训练很简单,最主要的就是决定如何分裂节点。lambdaMART采用最朴素的最小二乘法,也就是最小化平方误差和来分裂节点:即对于某个选定的feature,选定一个值val,所有<=val的样本分到左子节点,>val的分到右子节点。然后分别对左右两个节点计算平方误差和,并加在一起作为这次分裂的代价。遍历所有feature以及所有可能的分裂点val(每个feature按值排序,每个不同的值都是可能的分裂点),在这些分裂中找到代价最小的。

举个栗子,假设样本只有上一节中计算出 λλ 的那10个:

复制代码
 1 qId=1830 features and lambdas 2 qId=1830    1:0.003 2:0.000 3:0.000 4:0.000 5:0.003 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(1):-0.495 3 qId=1830    1:0.026 2:0.125 3:0.000 4:0.000 5:0.027 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(2):-0.206 4 qId=1830    1:0.001 2:0.000 3:0.000 4:0.000 5:0.001 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(3):-0.104 5 qId=1830    1:0.189 2:0.375 3:0.333 4:1.000 5:0.196 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(4):0.231 6 qId=1830    1:0.078 2:0.500 3:0.667 4:0.000 5:0.086 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(5):0.231 7 qId=1830    1:0.075 2:0.125 3:0.333 4:0.000 5:0.078 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(6):-0.033 8 qId=1830    1:0.079 2:0.250 3:0.667 4:0.000 5:0.085 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(7):0.240 9 qId=1830    1:0.148 2:0.000 3:0.000 4:0.000 5:0.148 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(8):0.24710 qId=1830    1:0.059 2:0.000 3:0.000 4:0.000 5:0.059 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(9):-0.05111 qId=1830    1:0.071 2:0.125 3:0.333 4:0.000 5:0.074 6:0.000 7:0.000 8:0.000 9:0.000 10:0.000    lambda(10):-0.061
复制代码

上表中除了第一列是qId,最后一列是lambda外,其余都是feature,比如我们选择feature(1)的0.059做分裂点,则左子节点<=0.059的doc有: 1, 2, 3, 9;而>0.059的被安排到右子节点,doc有4, 5, 6, 7, 8, 10。由此左右两个子节点的lambda均值分别为:

 

        λL¯=λ1+λ2+λ3+λ94=0.4950.2060.1040.0514=0.214λL¯=λ1+λ2+λ3+λ94=−0.495−0.206−0.104−0.0514=−0.214

        λR¯=λ4+λ5+λ6+λ7+λ8+λ106=0.231+0.2310.033+0.240+0.2470.0616=0.143λR¯=λ4+λ5+λ6+λ7+λ8+λ106=0.231+0.231−0.033+0.240+0.247−0.0616=0.143

 

继续计算左右子节点的平方误差和:

 

        sL=iL(λiλL¯)2=(0.495+0.214)2+(0.206+0.214)2+(0.104+0.214)2+(0.051+0.214)2=0.118sL=∑i∈L(λi−λL¯)2=(−0.495+0.214)2+(−0.206+0.214)2+(−0.104+0.214)2+(−0.051+0.214)2=0.118

        sR=iR(λiλR¯)2=(0.2310.143)2+(0.2310.143)2+(0.0330.143)2+(0.2400.143)2+(0.2470.143)2+(0.0160.143)2=0.083sR=∑i∈R(λi−λR¯)2=(0.231−0.143)2+(0.231−0.143)2+(−0.033−0.143)2+(0.240−0.143)2+(0.247−0.143)2+(0.016−0.143)2=0.083

 

因此将feature(1)的0.059的均方差(分裂代价)是:

 

        Cost0.059@feature(1)=sL+sR=0.118+0.083=0.201Cost0.059@feature(1)=sL+sR=0.118+0.083=0.201

 

我们可以像上面那样遍历所有feature的不同值,尝试分裂,计算Cost,最终选择所有可能分裂中最小Cost的那一个作为分裂点。然后将 sLsL 和 sRsR 分别作为左右子节点的属性存储起来,并把分裂的样本也分别存储到左右子节点中,然后维护一个队列,始终按平方误差和 s 降序插入新分裂出的节点,每次从该队列头部拿出一个节点(并基于这个节点上的样本)进行分裂(即最大均方差优先分裂),直到树的分裂次数达到参数设定(训练时传入的leaf值,叶子节点的个数与分裂次数等价)。这样我们就训练出了一棵Regression Tree。

 

上面讲述了一棵树的标准分裂过程,需要多提一点的是,树的分裂还有一个参数设定:叶子节点上的最少样本数,比如我们设定为3,则在feature(1)处,0.001和0.003两个值都不能作为分裂点,因为用它们做分裂点,左子树的样本数分别是1和2,均<3。叶子节点的最少样本数越小,模型则拟合得越好,当然也容易过拟合(over-fitting);反之如果设置得越大,模型则可能欠拟合(under-fitting),实践中可以使用cross validation的办法来寻找最佳的参数设定。



0 0