CART Regression Trees & Model Trees: Building and Pruning in Python


Real-world datasets often contain complex interactions that make the relationship between the input data and the target variable nonlinear. One viable way to model such relationships is to use a tree to make piecewise predictions, either piecewise constant or piecewise linear: the tree structure partitions the data, and at each leaf node we either take the mean of the samples there (yielding a regression tree) or fit a linear model to them (yielding a model tree).

Below, we will refer to CART-based regression trees and model trees collectively as tree regression.


1. Characteristics of Tree Regression


1.1 Compared with the ID3 decision tree discussed earlier, tree regression uses binary splits, so it does not use up features too quickly, and it can handle continuous-valued features.

1.2 Pros: can model complex and nonlinear data.

1.3 Cons: the results are not as easy to interpret as linear regression.

1.4 Model trees are more interpretable than regression trees, and they also tend to achieve higher prediction accuracy.


2. Comparing the Regression Methods


For model trees, regression trees, and the linear regression covered earlier, a fairly objective way to compare them is to compute the correlation coefficient between predicted and actual values, whose square is the R^2 value.

This takes a single call to NumPy's corrcoef(yHat, y, rowvar=0), where yHat is the vector of model predictions and y is the vector of actual values of the target variable.

The closer the value is to 1.0, the better the predictive performance.
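For example, here is a minimal sketch of such a comparison, using the implementation from Section 4 below; it assumes a single-feature dataset (column 0 is the input, column 1 is the target), and the file names train.txt and test.txt are placeholders:

from numpy import mat, corrcoef

trainMat = mat(loadDataSet('train.txt'))   # hypothetical file names
testMat = mat(loadDataSet('test.txt'))

# Regression tree
tree = createTree(trainMat, regLeaf, regErr, (1, 20))
yHat = createForecast(tree, testMat[:, 0], regTreeEval)
print(corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1])   # correlation R; square it for R^2

# Model tree
tree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
yHat = createForecast(tree, testMat[:, 0], modelTreeEval)
print(corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1])

Whichever method prints the value closest to 1.0 fits this dataset best.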


3. Pseudocode for the Main Functions


3.1 Finding the best split: the chooseBestSplit() function

If the target variable in the dataset takes only one value:
    Stop splitting and turn this dataset into a leaf node
For every feature:
    For every value of that feature:
        Split the dataset into two subsets
        Compute the total error of the two subsets
        If this error is less than the current minimum error:
            Record this split as the best split and update the minimum error
If the error reduction on the dataset does not reach the allowed tolerance:
    Stop splitting and turn this dataset into a leaf node
If either subset of the best split contains fewer samples than the set minimum:
    Stop splitting and turn this dataset into a leaf node
Return the feature and split value of the recorded best split
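For concreteness, a hedged note on the ops argument that the Section 4 implementation passes to chooseBestSplit(); its two entries are exactly the prepruning knobs referenced in the pseudocode above (dataMat below is a hypothetical data matrix):

tolS, tolN = ops
# tolS: the minimum error reduction a split must achieve, otherwise the node becomes a leaf
# tolN: the minimum number of samples allowed on each side of a split
# e.g. ops=(0, 1) grows a nearly maximal tree (useful before post-pruning),
# while a very large tolS such as ops=(10000, 4) can collapse the tree to a single leaf

feat, val = chooseBestSplit(dataMat, regLeaf, regErr, (1, 4))
# feat: column index of the best feature, or None if the node became a leaf
# val:  the split value, or the leaf value when feat is None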


3.2 Tree-building algorithm: the createTree() function

Call chooseBestSplit() to find the best feature to split on:
If the node cannot be split further (no feature is returned):
    Store the node as a leaf node
Otherwise, perform the binary split:
    Call createTree() recursively on the right subtree
    Call createTree() recursively on the left subtree
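The tree returned by the Section 4 implementation is a nested dict. For a single-feature dataset it might look like the following (the numbers are purely illustrative); 'left' holds the branch whose feature values exceed spVal, 'right' the branch at or below it, and for a model tree the leaves would hold weight matrices instead of scalars:

{'spInd': 0,        # index of the feature that was split on
 'spVal': 0.5,      # the split value
 'left': 1.98,      # a leaf: here, a regression-tree mean
 'right': 0.02}     # a leaf, or another nested dict for a deeper subtree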

3.3 Post-pruning algorithm: the prune() function

In fact, the three "if" tests inside chooseBestSplit() already preprune the tree as it grows. However, prepruning depends on the stopping parameters chosen for the algorithm (the ops tolerances), which are hard to tune well in practice, so we still post-prune the tree using a test dataset.

Split the test data according to the tree built earlier:
If either child is not a leaf but a subtree:
    Call prune() recursively on the corresponding subset
Compute the error of the current binary split: the sum of the errors on the two leaf nodes
Compute the error after merging the two leaves: the error of the single node obtained by replacing the split with the mean of the two leaf values
If merging lowers the error, merge the two leaf nodes
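A hedged sketch of the usual workflow with the Section 4 implementation (file names are placeholders): grow a deliberately oversized tree with loose prepruning, then let prune() merge leaves wherever the held-out error drops:

from numpy import mat

trainMat = mat(loadDataSet('train.txt'))   # hypothetical training file
testMat = mat(loadDataSet('test.txt'))     # hypothetical held-out file
bigTree = createTree(trainMat, regLeaf, regErr, (0, 1))  # loose prepruning: a very large tree
prunedTree = prune(bigTree, testMat)       # prints "merging" each time two leaves are merged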


4. Python Implementation

from numpy import *


def loadDataSet(fileName):
    # Load a tab-separated file, parsing every field as a float.
    # Returns a list of lists; callers wrap it in mat() to get a matrix.
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        currLine = line.strip().split('\t')
        fltLine = list(map(float, currLine))  # list() needed in Python 3
        dataMat.append(fltLine)
    fr.close()
    return dataMat


### Preparing for tree creation:
### regLeaf/modelLeaf build the leaf nodes, regErr/modelErr measure node error.

def regLeaf(dataSet):
    # Regression-tree leaf: the mean of the target values.
    return mean(dataSet[:, -1])


def regErr(dataSet):
    # Total squared error of the node: variance times the number of samples.
    # Used for measuring how uniform (or chaotic) the node's data is.
    return shape(dataSet)[0] * var(dataSet[:, -1])


def linearSolve(dataSet):
    # Ordinary least squares on the node's data; column 0 of X is the bias term.
    N, n = shape(dataSet)
    X = mat(ones((N, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


def modelLeaf(dataSet):
    # Model-tree leaf: the fitted linear-regression weights.
    ws, X, Y = linearSolve(dataSet)
    return ws


def modelErr(dataSet):
    # Squared error of the linear model on the node's data.
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))


### Creating the tree

def binSplitDataSet(dataSet, feature, value):
    # Binary split: mat0 gets the rows with feature > value, mat1 the rest.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def chooseBestSplit(dataSet, leafType, errType, ops):
    tolS = ops[0]  # minimum error reduction a split must achieve
    tolN = ops[1]  # minimum number of samples in each split subset
    # If all target values are identical, make a leaf.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    N, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave too few samples on either side.
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # If the best split does not reduce the error enough, make a leaf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType, errType, ops):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:  # a stopping condition was met: return the leaf value
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


### Post-pruning

def isTree(obj):
    return (type(obj).__name__ == 'dict')


def getMean(tree):
    # Collapse a subtree bottom-up into the mean of its leaves.
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['right'] + tree['left']) / 2.0


def prune(tree, testData):
    # No test data reaches this subtree: collapse it to its mean.
    if shape(testData)[0] == 0:
        return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # If both children are now leaves, test whether merging them helps.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                     sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errMerge < errNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree


### Predicting

def regTreeEval(model, inData):
    # A regression-tree leaf stores a constant prediction.
    return float(model)


def modelTreeEval(model, inData):
    # A model-tree leaf stores linear weights; prepend a 1 for the bias term.
    n = shape(inData)[1]
    X = mat(ones((1, n + 1)))  # ones, not zeros, so the bias weight takes effect
    X[:, 1:n+1] = inData
    return float(X * model)


def treeForecast(tree, inData, treeEval):
    # Walk the tree for a single (1 x n) input row, then evaluate the leaf.
    if not isTree(tree):
        return treeEval(tree, inData)
    if inData[0, tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForecast(tree['left'], inData, treeEval)
        else:
            return treeEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForecast(tree['right'], inData, treeEval)
        else:
            return treeEval(tree['right'], inData)


def createForecast(tree, testData, treeEval):
    # Predict every row of testData, returning an (M x 1) column of forecasts.
    M = len(testData)
    yHat = mat(zeros((M, 1)))
    for ii in range(M):
        yHat[ii, 0] = treeForecast(tree, mat(testData[ii]), treeEval)
    return yHat
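As a quick smoke test, here is a minimal sketch on synthetic data (illustrative only): the target is a noisy step function of a single feature, so the first split of a regression tree should land near x = 0.5 and the fit should correlate strongly with the data:

import numpy as np
from numpy import mat

np.random.seed(0)
x = np.linspace(0.0, 1.0, 200)
y = np.where(x < 0.5, 1.0, 5.0) + np.random.normal(0.0, 0.1, 200)
data = mat(np.column_stack((x, y)))

tree = createTree(data, regLeaf, regErr, (1, 10))
print(tree['spInd'], tree['spVal'])   # expect feature 0 with a split value near 0.5
yHat = createForecast(tree, data[:, 0], regTreeEval)
print(np.corrcoef(yHat, data[:, 1], rowvar=0)[0, 1])  # should be close to 1.0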


