Python实现回归树

来源:互联网 发布:赌球外围软件 编辑:程序博客网 时间:2024/06/12 11:26

 

正文

----------------------------------------------------------------------------------------

本系列文章为《机器学习实战》学习笔记,内容整理自书本,网络以及自己的理解,如有错误欢迎指正。

源码在Python3.5上测试均通过,代码及数据 --> https://github.com/Wellat/MLaction

----------------------------------------------------------------------------------------

回到顶部

1、连续和离散型特征的树的构建 

决策树算法主要是不断将数据切分成小数据集,直到所有目标变量完全相同,或者数据不能再切分为止。它是一种贪心算法,并不考虑能否达到全局最优。前面介绍的用ID3构建决策树的算法每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来划分,这种切分过于迅速,且不能处理连续性特征。另外一种方法是二元切分法,每次把数据集切成两份,如果数据的某特征等于切分所要求的值,那么这些数据就进入树的左子树,反之右子树。二元切分法可处理连续型特征,节省树的构建时间。

这里依然使用字典来存储树的数据结构,该字典将包含以下4个元素:

  • 待切分的特征
  • 待切分的特征值
  • 右子树,不需切分时,也可是单个值
  • 左子树,右子树类似

本章将构建两种树:第一种是第2节的回归树(regression tree),其每个叶节点包含单个值;第二种是第3节的模型树(model tree),其每个叶节点包含一个线性方程。创建这两种树时,我们将尽量使得代码之间可以重用。下面先给出两种树构建算法中的一些共用代码。

复制代码
 1 from numpy import * 2  3 def loadDataSet(fileName): 4     ''' 5     读取一个一tab键为分隔符的文件,然后将每行的内容保存成一组浮点数     6     ''' 7     dataMat = [] 8     fr = open(fileName) 9     for line in fr.readlines():10         curLine = line.strip().split('\t')11         fltLine = map(float,curLine)12         dataMat.append(fltLine)13     return dataMat14 15 def binSplitDataSet(dataSet, feature, value):16     '''17     数据集切分函数    18     '''19     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]20     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]21     return mat0,mat122 23 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):24     '''25     树构建函数26     leafType:建立叶节点的函数27     errType:误差计算函数28     ops:包含树构建所需其他参数的元组    29     '''    30     #选择最优的划分特征31     #如果满足停止条件,将返回None和某类模型的值32     #若构建的是回归树,该模型是一个常数;如果是模型树,其模型是一个线性方程33     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)34     if feat == None: return val #35     retTree = {}36     retTree['spInd'] = feat37     retTree['spVal'] = val38     #将数据集分为两份,之后递归调用继续划分39     lSet, rSet = binSplitDataSet(dataSet, feat, val)40     retTree['left'] = createTree(lSet, leafType, errType, ops)41     retTree['right'] = createTree(rSet, leafType, errType, ops)42     return retTree  
复制代码

 

回到顶部

2、CART回归树

CART(Classification And Regression Trees, 分类回归树)是十分著名的树构建算法,它使用二元切分来处理连续性变量,对其稍作修改就可处理回归问题。

2.1 构建树

①切分数据集并生成叶节点

给定某个误差计算方法,chooseBestSplit()函数会找到数据集上最佳的二元切分方式,此外,该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。该函数伪代码大致如下:

②计算误差

这里采用计算数据的平方误差。

Python代码:

复制代码
 1 def regLeaf(dataSet): 2     '''负责生成叶节点''' 3     #当chooseBestSplit()函数确定不再对数据进行切分时,将调用本函数来得到叶节点的模型。 4     #在回归树中,该模型其实就是目标变量的均值。 5     return mean(dataSet[:,-1]) 6  7 def regErr(dataSet): 8     ''' 9     误差估计函数,该函数在给定的数据上计算目标变量的平方误差,这里直接调用均方差函数10     '''11     return var(dataSet[:,-1]) * shape(dataSet)[0]#返回总方差12 13 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):14     '''15     用最佳方式切分数据集和生成相应的叶节点16     '''  17     #ops为用户指定参数,用于控制函数的停止时机18     tolS = ops[0]; tolN = ops[1]19     #如果所有值相等则退出20     if len(set(dataSet[:,-1].T.tolist()[0])) == 1:21         return None, leafType(dataSet)22     m,n = shape(dataSet)23     S = errType(dataSet)24     bestS = inf; bestIndex = 0; bestValue = 025     #在所有可能的特征及其可能取值上遍历,找到最佳的切分方式26     #最佳切分也就是使得切分后能达到最低误差的切分27     for featIndex in range(n-1):28         for splitVal in set(dataSet[:,featIndex]):29             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)30             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue31             newS = errType(mat0) + errType(mat1)32             if newS < bestS: 33                 bestIndex = featIndex34                 bestValue = splitVal35                 bestS = newS36     #如果误差减小不大则退出37     if (S - bestS) < tolS: 38         return None, leafType(dataSet)39     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)40     #如果切分出的数据集很小则退出41     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):42         return None, leafType(dataSet)43     #提前终止条件都不满足,返回切分特征和特征值44     return bestIndex,bestValue
复制代码

主要测试命令:

>>> reload(regTrees)>>> myData = regTrees.loadDataSet('ex00.txt')>>> myMat = mat(myData)>>> regTrees.createTree(myMat)

【注意】本代码在Python3.5环境下测试未通过,错误发生在以上第5行-->return mean(dataSet[:,-1])

错误类型为 TypeError: unsupported operand type(s) for /: 'map' and 'int' 暂未找到解决办法。所以,以下测试结果均来自书本。

 

2.2 剪枝

一棵树如果节点过多,表明该模型可能对数据进行了“过拟合”。通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning) 。

①预剪枝

在函数chooseBestSplit()中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作。树构建算法其实对输人的参数tols和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。 

②后剪枝

使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差。如果是的话就合并 。

Python实现代码:

复制代码
 1 def prune(tree, testData): 2     '''回归树剪枝函数''' 3     if shape(testData)[0] == 0: return getMean(tree) #无测试数据则返回树的平均值 4     if (isTree(tree['right']) or isTree(tree['left'])):# 5         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) 6     if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) 7     if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet) 8     #如果两个分支已经不再是子树,合并它们 9     #具体做法是对合并前后的误差进行比较。如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回10     if not isTree(tree['left']) and not isTree(tree['right']):11         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])12         errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\13             sum(power(rSet[:,-1] - tree['right'],2))14         treeMean = (tree['left']+tree['right'])/2.015         errorMerge = sum(power(testData[:,-1] - treeMean,2))16         if errorMerge < errorNoMerge: 17             print("merging")18             return treeMean19         else: return tree20 21 def isTree(obj):22     '''判断输入变量是否是一棵树'''23     return (type(obj).__name__=='dict')24 25 def getMean(tree):26     '''从上往下遍历树直到叶节点为止,计算它们的平均值'''27     if isTree(tree['right']): tree['right'] = getMean(tree['right'])28     if isTree(tree['left']): tree['left'] = getMean(tree['left'])29     return (tree['left']+tree['right'])/2.0
复制代码

测试命令:

复制代码
reload(regTrees)myData2 = regTrees.loadDataSet('ex2.txt')myMat2 = mat(myData2)from numpy import *myMat2 = mat(myData2)regTrees.createTree(myMat2)myTree = regTrees.createTree(myMat2, ops=(0,1))myDataTest = regTrees.loadDataSet('ex2test.txt')myMat2Test = mat(myDataTest)regTrees.prune(myTree, myMat2Test)
复制代码

 

回到顶部

3、模型树

①叶节点

用树建模,除了把叶节点简单地设定为常数值外,还可把叶节点设定为分段线性函数,这里的分段线性是指模型由多个线性片段组成。

如下图所示数据,如果使用两条直线拟合是否比使用一组常数来建模好呢?答案显而易见。可以设计两条分别从0.0~0.3、从0.3~1.0的直线,于是就可以得到两个线性模型。因为数据集里的一部分数据(0.0~0.3)以某个线性模型建模,而另一部分数据(0.3~1.0)则以另一个线性模型建模,因此我们说采用了所谓的分段线性模型。

②误差计算

前面用于回归树的误差计算方法这里不能再用。稍加变化,对于给定的数据集,先用线性的模型来对它进行拟合,然后计算真实的目标值与模型预测值间的差值。最后将这些差值的平方求和就得到了所需的误差。 

与回归树不同,模型树Python代码有以下变化:

复制代码
 1 def linearSolve(dataSet): 2     '''将数据集格式化成目标变量Y和自变量X,X、Y用于执行简单线性回归''' 3     m,n = shape(dataSet) 4     X = mat(ones((m,n))); Y = mat(ones((m,1))) 5     X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#默认最后一列为Y 6     xTx = X.T*X 7     #若矩阵的逆不存在,抛异常 8     if linalg.det(xTx) == 0.0: 9         raise NameError('This matrix is singular, cannot do inverse,\n\10         try increasing the second value of ops')11     ws = xTx.I * (X.T * Y)#回归系数12     return ws,X,Y13 14 def modelLeaf(dataSet):15     '''负责生成叶节点模型'''16     ws,X,Y = linearSolve(dataSet)17     return ws18 19 def modelErr(dataSet):20     '''误差计算函数'''21     ws,X,Y = linearSolve(dataSet)22     yHat = X * ws23     return sum(power(Y - yHat,2))
复制代码

测试命令:

>>> regTrees.createTree(myMat,regTrees.modelLeaf,regTrees.modelErr.(1,10))
回到顶部

4、实例:树回归与标准回归的比较

前面介绍了模型树、回归树和一般的回归方法,下面测试一下哪个模型最好。这些模型将在某个数据上进行测试,该数据涉及人的智力水平和自行车的速度的关系。

复制代码
 1 def createForeCast(tree, testData, modelEval=regTreeEval): 2     # 多次调用treeForeCast()函数,以向量形式返回预测值,在整个测试集进行预测非常有用 3     m=len(testData) 4     yHat = mat(zeros((m,1))) 5     for i in range(m): 6         yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval) 7     return yHat 8  9 def treeForeCast(tree, inData, modelEval=regTreeEval):10     '''11     # 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。12     # modeEval是对叶节点进行预测的函数引用,指定树的类型,以便在叶节点上调用合适的模型。13     # 此函数自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上14     # 调用modelEval()函数,该函数的默认值为regTreeEval()    15     '''16     if not isTree(tree): return modelEval(tree, inData)17     if inData[tree['spInd']] > tree['spVal']:18         if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)19         else: return modelEval(tree['left'], inData)20     else:21         if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)22         else: return modelEval(tree['right'], inData)23 24 def regTreeEval(model, inDat):25     #为了和modeTreeEval()保持一致,保留两个输入参数26     return float(model)27 28 def modelTreeEval(model, inDat):29     #对输入数据进行格式化处理,在原数据矩阵上增加第0列,元素的值都是130     n = shape(inDat)[1]31     X = mat(ones((1,n+1)))32     X[:,1:n+1]=inDat33     return float(X*model)
复制代码

测试命令:

复制代码
#回归树>>> reload(regTrees)>>> trainMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_train.txt'))>>> testMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_test.txt'))>>> myTree = regTrees.createTree(trainMat, ops=(1,20))>>> yHat = regTrees.createForeCast(myTree, testMat[:,0])>>> corrcoef(yHat, testMat[:,1], rowvar=0)array([[ 1.        ,  0.96408523],       [ 0.96408523,  1.        ]])#模型树>>> myTree = regTrees.createTree(trainMat, regTrees.modelLeaf, regTrees.modelErr, (1,20))>>> yHat = regTrees.createForeCast(myTree, testMat[:,0], regTrees.modelTreeEval)>>> corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]0.97604121913806285# 标准回归>>> ws, X, Y = regTrees.linearSolve(trainMat)>>> wsmatrix([[ 37.58916794],        [  6.18978355]])>>> for i in range(shape(testMat)[0]) :...     yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0]...>>> corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]0.94346842356747584
复制代码

 

THE END.

回归算法原理

CART(Classification and Regression Tree)算法是目前决策树算法中最为成熟的一类算法,应用范围也比较广泛。它既可以用于分类。 
西方预测理论一般都是基于回归的,CART是一种通过决策树方法实现回归的算法,它具有很多其他全局回归算法不具有的特性。 
在创建回归模型时,样本的取值分为观察值和输出值两种,观察值和输出值都是连续的,不像分类函数那样有分类标签,只有根据数据集的数据特征来创建一个预测的模型,反映曲线的变化趋势。在这种情况下,原有分类树的最优划分规则就不再起作用了。在预测中,CART使用最小剩余方差(Squared Residuals Minimization)来判定回归树的最优划分,这个准则期望划分之后的子树与样本点的误差方差最小。这样决策树将数据集划分成很多子模型数据,然后利用线性回归技术来建模。如果每次切分后的数据子集仍然难以拟合,就继续切分。在这种切分方式下创建出的预测树,每个叶子节点都是一个线性回归模型。这些线性回归模型反映了样本集合(观测集合)中蕴含的模式,也被称为模型树。因此,CART不仅支持正体预测,也支持局部模式的预测,并有能力从整体中找到模式,或根据模式组合成一个整体。整体与模式之间的相互结合,对于预测分析有重要价值。因此CART决策树算法在预测中的应用非常广泛。 
下面介绍CART的算法流程: 
(1)决策树主函数:决策树的主函数是一个递归函数。该函数的主要功能是按照CART的规则生长出决策树的每个分支节点,并根据终止条件结束算法。 
a.输入需要分类的数据集和类别标签。 
b.使用最小剩余方差判定回归树的最优划分,并创建特征的划分节点——最小剩余方差子函数。 
c.在划分节点划分数据集为两部分——二分数据集子函数。 
d.根据二分数据的结果构建出新的左右节点,作为树生长出的两个分支。 
e.检验是否符合递归的终止条件。 
f.将划分的新节点包含的数据集和类别标签作为输入,递归执行上述步骤。 
(2)使用最小剩余方差子函数,计算数据集各列的最优划分方差、划分列、划分值 
(3)二分数据集:根据给定的分隔列和分隔值将数据集一分为二,分别返回。

最小剩余方差法

在回归树中,数据集均为连续性。连续数据的处理方法与离散数据不同,离散数据是按每个特征的取值划分,而连续特征则要计算出最优划分点。但在连续数据集上计算线性相关度非常简单,算法思想来源于最小二乘法。 
最小剩余方差法,首先求取划分数据列的均值和总方差。总方差的计算方法有两种 
求取均值std,计算每个数据点与std的方差,然后将n个点求和。 
求取方差var,然后var_sum = var*n,n为数据集数据数目。 
那么,每次最佳分支特征的选取过程如下。 
(1)先令最佳方差为无限大 bestVar = inf。 
(2)此次遍历所有特征列及每个特征列的所有样本点(这是一个二循环),在每个样本点上二分数据集。 
(3)计算二分数据集后的总方差currentVar,如果currentVar < bestVar,则bestVar = currentVar。 
返回计算的最优分支特征列、分支特征值(连续特征则为划分点的值)以及左右分支子数据集到主程序。

模型树

使用CART进行预测是把叶子节点设定为一系列的分段线性函数,这些分段线性函数是对源数据曲线的一种模拟,每个线性函数都被称为一颗模型树。模型树具有很多优秀的性质,它包含了如下特征。 
一般而言,样本总体的重复性不会很高,但局部模式经常重复,也就是所说的历史不会简单的重复,但会重演。模型比总体对未来的预测而言更有用。 
模型给出了数据的范围,它可能是一个时间范围,也可能是一个空间范围;而且模型还给出了变化的趋势,可以是曲线,也可以是直线,这依赖于使用的回归算法。这些因素使模型具有很强的可解释性。 
传统的回归方法,无论是线性回归还是非线性回归,都不如模型树包含的信息丰富,因此模型树具有更高的预测准确度。

Scikit-Learn实现

#!/usr/bin/python# created by lixin 20161118import numpy as npfrom numpy import *from sklearn.tree import DecisionTreeRegressorimport matplotlib.pyplot as pltdef plotfigure(X,X_test,y,yp):        plt.figure()        plt.scatter(X,y,c="k",label="data")        plt.plot(X_test,yp,c="r",label="max_depth=5",linewidth=2)        plt.xlabel("data")        plt.ylabel("target")        plt.title("Decision Tree Regression")        plt.legend(loc='upper right')        plt.show()        #plt.savefig('./res.png', format='png')x = np.linspace(-5,5,200)siny = np.sin(x)X = mat(x).Ty = siny + np.random.rand(1,len(siny))*1.5y = y.tolist()[0]clf = DecisionTreeRegressor(max_depth=4)clf.fit(X,y)X_test = np.arange(-5.0,5.0,0.05)[:,np.newaxiyp = clf.predict(X_test)plotfigure(X,X_test,y,yp)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

图1


原创粉丝点击