树回归源码分析(1)

来源:互联网 发布:wifi速度测试软件 编辑:程序博客网 时间:2024/06/05 11:43

线性回归包含了强大的方法,但是需要拟合所有的数据集(局部加权线性回归除外),但是当数据特征复杂时,构建全局模型就难了,况且实际生活很多问题都是非线性的,不可能使用全局线性模型来拟合所有的数据。
现有可以将数据集切分成很多易建模的数据,然后再利用线性回归技术建模,以就得到了CART——Classification And Regression Tree(分类回归树)的树构建算法,该算法既可以用于分类还可以用于回归。

二元切分法:即每次把数据集切成两份。如果数据的某特征值等于切分所要求的值,那么这些数据就进人树的左子树,反之则进人树的右子树。

使用二元切分法则易于对树构建过程进行调整以处理连续型特征。具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。

CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值。字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。

1. CART算法用于回归

几点要明确的地方:

  • ID3算法会在给定节点时计算数据的混乱度,而连续数值的混乱度的度量用方差(平方误差的均值),这里用总方差(平方误差的总值),总方差可以通过均方差乘以数据集中的样本点的个数来得到。
  • 源代码有错误的地方,解决方法:源代码错误修正
# -*- coding: utf-8 -*-"""Created on Fri Nov 03 10:35:00 2017"""from numpy import *# 加载数据函数def loadDataSet(fileName):         dataMat = []                    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')  # 读取以tab键为分割符的文件        fltLine = map(float,curLine)   # 将每行映射为浮点数        dataMat.append(fltLine)  # 把所有的数据保存到一起    return dataMat# 二元切分数据集def binSplitDataSet(dataSet, feature, value): # 三个参数:数据集合,待切分的特征,和该特征的某个值    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] # 数组过滤,mat0是特征数列中大于value的所有样本行    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] # 得到和feature相对应的满足要求的样本    return mat0,mat1 # 返回两个子集,分别是针对某特征列划分的不同样本集# 生成叶节点    def regLeaf(dataSet):      return mean(dataSet[:,-1])  # 在回归树种返回目标变量的均值# 误差估计函数,计算连续值的混乱度def regErr(dataSet):   # var()均方差函数,要返回总方差,所以要用均方差乘以数据集中的样本个数    return var(dataSet[:,-1]) * shape(dataSet)[0] # 用最佳方式切分数据集和生成相应的叶节点。leafType,errType是对函数的引用def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    tolS = ops[0]; tolN = ops[1] # tolS容许的误差下降值,tolN切分的最少样本数    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: # 如果特征数目只剩一个,就不再切分,直接返回        print 'back from here 1 ..'        return None, leafType(dataSet)    m,n = shape(dataSet) # 当前数据集的大小    S = errType(dataSet) # 计算误差,s用于和新切分误差对比    bestS = inf; bestIndex = 0; bestValue = 0     for featIndex in range(n-1):  # 遍历所有的特征,除了最后一个        for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]): # 针对每个特征,在所有样本中查看不同的特征值                 mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 将数据集切分两份                        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):   # 切分的最少样本数                               continue              newS = errType(mat0) + errType(mat1)  # 计算切分的总方差                           if newS < bestS:                  print 'featIndex,splitVal:',featIndex,'and', splitVal                 bestIndex = featIndex  # 如果新的总方差小于当前的方差,则返回特征索引和切分特征值                 bestValue = splitVal                 bestS = newS                     if (S - bestS) < tolS:   # 如果容错的误差下降值变化不大,就停止切分,直接创造叶节点        print 'back from here 2 ..'        return None, leafType(dataSet)       mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果切分的数据集很小则退出直接创造叶节点        print 'back from here 3 ..'        return None, leafType(dataSet)    return bestIndex,bestValue  # 如果所有的提前终止条件都不满足,就返回切分特征和特征值# 找到数据的最佳二元切分方式def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): # ops是一个包含树构建所需的参数元组    # 把数据集分成两部分,如果满足停止条件返回None和某类模型的值    # 满足停止条件:feat是None,val是某类模型的值    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)    print 'feat, val :',feat, 'and',val    if feat == None:        print 'back creatTree..'        return val # 回归树:模型是常数,模型树:模型是线性方程     retTree = {}     retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataSet(dataSet, feat, val) # 不满足停止条件时,lSet, rSet是两个数据集    retTree['left'] = createTree(lSet, leafType, errType, ops) # 递归调用createTree()函数    retTree['right'] = createTree(rSet, leafType, errType, ops)    return retTree  # 主函数testMat=mat(eye(4))mat0,mat1=binSplitDataSet(testMat,1,0.5)print 'mat0:',mat0print 'mat1:',mat1myDat=loadDataSet('ex00.txt')myMat=mat(myDat)print createTree(myMat)

运行结果:

mat0: [[ 0.  1.  0.  0.]]mat1: [[ 1.  0.  0.  0.] [ 0.  0.  1.  0.] [ 0.  0.  0.  1.]]featIndex,splitVal: 0 and 0.302001featIndex,splitVal: 0 and 0.55299featIndex,splitVal: 0 and 0.378595featIndex,splitVal: 0 and 0.406649featIndex,splitVal: 0 and 0.475976featIndex,splitVal: 0 and 0.48813feat, val : 0 and 0.48813featIndex,splitVal: 0 and 0.936783featIndex,splitVal: 0 and 0.727098featIndex,splitVal: 0 and 0.72312featIndex,splitVal: 0 and 0.645762featIndex,splitVal: 0 and 0.648675featIndex,splitVal: 0 and 0.625336featIndex,splitVal: 0 and 0.622398featIndex,splitVal: 0 and 0.620599back from here 2 ..feat, val : None and 1.01809676724back creatTree..featIndex,splitVal: 0 and 0.302001featIndex,splitVal: 0 and 0.347837featIndex,splitVal: 0 and 0.346986featIndex,splitVal: 0 and 0.188218featIndex,splitVal: 0 and 0.048014featIndex,splitVal: 0 and 0.343479back from here 2 ..feat, val : None and -0.0446502857143back creatTree..{'spInd': 0, 'spVal': 0.48813, 'right': -0.044650285714285719, 'left': 1.0180967672413792}

可以由运行结果看出代码的具体运行过程:

  • 叶节点是相应的目标数据集的均值
  • 注意几个切分停止得条件和返回叶节点

2. 树剪枝

通过降低决策树的复杂度来避免过拟合的过程称为剪枝,在上面的提前终止条件实际是一种预剪枝的操作。另一种是使用测试集和训练集,称为后剪枝。

  • 树构建算法其实对输入的tolS和tolN非常敏感,也就是对提前终止的人为输入参数,其中tolS对误差的数量级十分敏感,所以需要我们手动调节参数,但是通过不断修改停止条件来得到合理的结果并不是很好的办法,甚至有时候我们不确定到底我们需要什么样的结果,于是有了通过测试集来对树进行剪枝,也就避免了用户指定参数。

后剪枝

函数prune()的伪代码如下:

基于已有的树切分测试数据:

  • 如果存在任一子集是一棵树,则在该子集递归剪枝过程
  • 计算将当前两个叶节点合并后的误差
  • 计算不合并的误差
  • 如果合并会降低误差的话,就将叶节点合并
# -*- coding: utf-8 -*-"""Created on Fri Nov 03 10:35:00 2017"""from numpy import *# 加载数据函数def loadDataSet(fileName):         dataMat = []                    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')  # 读取以tab键为分割符的文件        fltLine = map(float,curLine)   # 将每行映射为浮点数        dataMat.append(fltLine)  # 把所有的数据保存到一起    return dataMat# 二元切分数据集def binSplitDataSet(dataSet, feature, value): # 三个参数:数据集合,待切分的特征,和该特征的某个值    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] # 数组过滤,mat0是特征数列中大于value的所有样本行    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] # 得到和feature相对应的满足要求的样本    return mat0,mat1 # 返回两个子集,分别是针对某特征列划分的不同样本集# 生成叶节点    def regLeaf(dataSet):      return mean(dataSet[:,-1])  # 在回归树种返回目标变量的均值# 误差估计函数,计算连续值的混乱度def regErr(dataSet):   # var()均方差函数,要返回总方差,所以要用均方差乘以数据集中的样本个数    return var(dataSet[:,-1]) * shape(dataSet)[0] # 用最佳方式切分数据集和生成相应的叶节点。leafType,errType是对函数的引用def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    tolS = ops[0]; tolN = ops[1] # tolS容许的误差下降值,tolN切分的最少样本数    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: # 如果特征数目只剩一个,就不再切分,直接返回        print 'back from here 1 ..'        return None, leafType(dataSet)    m,n = shape(dataSet) # 当前数据集的大小    S = errType(dataSet) # 计算误差,s用于和新切分误差对比    bestS = inf; bestIndex = 0; bestValue = 0     for featIndex in range(n-1):  # 遍历所有的特征,除了最后一个        for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]): # 针对每个特征,在所有样本中查看不同的特征值                 mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 将数据集切分两份                        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):   # 切分的最少样本数                               continue              newS = errType(mat0) + errType(mat1)  # 计算切分的总方差                           if newS < bestS:                  print 'featIndex,splitVal,newS:',featIndex,'and', splitVal,'and', newS                 bestIndex = featIndex  # 如果新的总方差小于当前的方差,则返回特征索引和切分特征值                 bestValue = splitVal                 bestS = newS                     if (S - bestS) < tolS:   # 如果容错的误差下降值变化不大,就停止切分,直接创造叶节点        print 'back from here 2 ..'        return None, leafType(dataSet)       mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果切分的数据集很小则退出直接创造叶节点        print 'back from here 3 ..'        return None, leafType(dataSet)    return bestIndex,bestValue  # 如果所有的提前终止条件都不满足,就返回切分特征和特征值# 找到数据的最佳二元切分方式def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): # ops是一个包含树构建所需的参数元组    # 把数据集分成两部分,如果满足停止条件返回None和某类模型的值    # 满足停止条件:feat是None,val是某类模型的值    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)    print 'feat, val :',feat, 'and',val    if feat == None:        print 'back creatTree..'        return val # 回归树:模型是常数,模型树:模型是线性方程     retTree = {}     retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataSet(dataSet, feat, val) # 不满足停止条件时,lSet, rSet是两个数据集    retTree['left'] = createTree(lSet, leafType, errType, ops) # 递归调用createTree()函数    retTree['right'] = createTree(rSet, leafType, errType, ops)    return retTree  # 回归树剪枝函数def isTree(obj): # 测试一个输入变量是否是树类型    return (type(obj).__name__=='dict')  # 判断当前处理的节点是否是叶节点def getMean(tree): # 叶节点处理函数,是一个递归函数,从上往下遍历树直到叶节点为止    if isTree(tree['right']):         tree['right'] = getMean(tree['right'])    if isTree(tree['left']):         tree['left'] = getMean(tree['left'])    return (tree['left']+tree['right'])/2.0   # 返回整个树的平均值def prune(tree, testData): # testData待测试的数据,tree是由训练集生成的    if shape(testData)[0] == 0: # 如果没有测试数据,则对树进行塌陷处理        return getMean(tree)      if (isTree(tree['right']) or isTree(tree['left'])): # 判断当前分支是否是树        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])    if isTree(tree['left']): # 左树剪枝        tree['left'] = prune(tree['left'], lSet) # 反复调用prune()对测试数据进行切分    if isTree(tree['right']): # 右树剪枝        tree['right'] =  prune(tree['right'], rSet)    # 如果左右两个不再是子树,就进行合并    if not isTree(tree['left']) and not isTree(tree['right']):        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) \                          +sum(power(rSet[:,-1] - tree['right'],2)) # 此处用的是平方误差        treeMean = (tree['left']+tree['right'])/2.0        errorMerge = sum(power(testData[:,-1] - treeMean,2))        if errorMerge < errorNoMerge:  # 比较剪枝前后的误差变化            print "merging"            return treeMean        else: return tree    else: return tree# 主函数myDat2=loadDataSet('ex2.txt')myMat2=mat(myDat2)myTree=createTree(myMat2,ops=(0,1)) # createTree(myMat)返回的值是dict类型的myDatTest=loadDataSet('ex2test.txt')myMat2Test=mat(myDatTest)print '..............................'  print prune(myTree,myMat2Test)             

运行结果:

mergingmergingmergingmergingmergingmergingmergingmergingmergingmergingmergingmergingmerging...'spVal': 0.965969, 'right': {'spInd': 0, 'spVal': 0.956951, 'right': 111.2013225, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 135.83701300000001, 'left': {'spInd': 0, 'spVal': 0.960398, 'right': 123.559747, 'left': 112.386764}}}, 'left': 92.523991499999994}}}}

可以看出大量的节点已经被剪枝掉了,虽然看着还是那么多的节点,但是确实已经减少了很多了,一般情况下为了寻求最佳模型可以同时使用预剪枝和后剪枝两种技术。

注意:

  • 其中的塌陷处理,自上而下的遍历树到叶节点为止,如果找到两个叶节点则计算它们的平均值,返回整个树的平均值。
  • 注意其中的递归调用剪枝处理,对数据结构的理解有所要求。

3. 模型树

简单来说就是把原来的叶节点由常数值变为分段线性函数,所谓的分段线性就是指模型由多个线性片段组成。也就是在某些情况下,分段线性要比很多节点组成的一颗大树更容易解释。

  • 模型树的可解释性优于回归树的,另外模型树也具有更高的预测准确度。
  • 前面用于回归树的误差计算方法这里不能再用。稍加变化,对于给定的数据集,应该先用线性的模型来对它进行拟合,然后计算真实的目标值与模型预测值间的差值。最后将这些差值的平方求和就得到了所需的误差

在CART算法用于回归代码中加入下面的函数,并且把主函数改为如下:

# 模型树的叶节点生成函数def linearSolve(dataSet):       m,n = shape(dataSet)    X = mat(ones((m,n))); Y = mat(ones((m,1)))      X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1] # 将数据集格式化成目标变量和自变量     xTx = X.T*X    if linalg.det(xTx) == 0.0: # 这个矩阵是奇异的,不能求逆        raise NameError('This matrix is singular, cannot do inverse,\n\        try increasing the second value of ops')    ws = xTx.I * (X.T * Y) # 线性回归的系数    return ws,X,Ydef modelLeaf(dataSet):  # 生成叶节点模型    ws,X,Y = linearSolve(dataSet)    return ws # 返回回归系数def modelErr(dataSet):    ws,X,Y = linearSolve(dataSet)    yHat = X * ws    return sum(power(Y - yHat,2)) # 在给定数据集上计算误差,返回平方误差# 主函数# 模型树myMat2=mat(loadDataSet('exp2.txt'))modelTree=createTree(myMat2, modelLeaf,modelErr,(1,10))print '模型树:',modelTree

运行结果:

模型树: {'spInd': 0, 'spVal': 0.285477, 'right': matrix([[ 3.46877936],        [ 1.18521743]]), 'left': matrix([[  1.69855694e-03],        [  1.19647739e+01]])}

这里写图片描述

由运行的结果可以看出:
分段线性生成的模型:
y=3.468+1.18521743x
y=0.0016985+11.96477x
而数据是由模型:
y=3.5+1.0x
y=0.0+12x再加上高斯噪声生成的。

两个模型已经非常接近了。

4. 树回归的比较

模型树、回归树以及第8章里的其他模型,哪一种模型更好呢?一个比较客观的方法是计算相关系数,也称为R2值。该相关系数可以通过调用Numpy库中的命令corrcoef(yHat,y,rowvar)来求解。

# -*- coding: utf-8 -*-"""Created on Fri Nov 03 10:35:00 2017"""from numpy import *# 加载数据函数def loadDataSet(fileName):         dataMat = []                    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')  # 读取以tab键为分割符的文件        fltLine = map(float,curLine)   # 将每行映射为浮点数        dataMat.append(fltLine)  # 把所有的数据保存到一起    return dataMat# 二元切分数据集def binSplitDataSet(dataSet, feature, value): # 三个参数:数据集合,待切分的特征,和该特征的某个值    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] # 数组过滤,mat0是特征数列中大于value的所有样本行    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] # 得到和feature相对应的满足要求的样本    return mat0,mat1 # 返回两个子集,分别是针对某特征列划分的不同样本集# 生成叶节点    def regLeaf(dataSet):      return mean(dataSet[:,-1])  # 在回归树种返回目标变量的均值# 误差估计函数,计算连续值的混乱度def regErr(dataSet):   # var()均方差函数,要返回总方差,所以要用均方差乘以数据集中的样本个数    return var(dataSet[:,-1]) * shape(dataSet)[0] # 用最佳方式切分数据集和生成相应的叶节点。leafType,errType是对函数的引用def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    tolS = ops[0]; tolN = ops[1] # tolS容许的误差下降值,tolN切分的最少样本数    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: # 如果特征数目只剩一个,就不再切分,直接返回        print 'back from here 1 ..'        return None, leafType(dataSet)    m,n = shape(dataSet) # 当前数据集的大小    S = errType(dataSet) # 计算误差,s用于和新切分误差对比    bestS = inf; bestIndex = 0; bestValue = 0     for featIndex in range(n-1):  # 遍历所有的特征,除了最后一个        for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]): # 针对每个特征,在所有样本中查看不同的特征值                 mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 将数据集切分两份                        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):   # 切分的最少样本数                               continue              newS = errType(mat0) + errType(mat1)  # 计算切分的总方差                           if newS < bestS:                  print 'featIndex,splitVal:',featIndex,'and', splitVal                 bestIndex = featIndex  # 如果新的总方差小于当前的方差,则返回特征索引和切分特征值                 bestValue = splitVal                 bestS = newS                     if (S - bestS) < tolS:   # 如果容错的误差下降值变化不大,就停止切分,直接创造叶节点        print 'back from here 2 ..'        return None, leafType(dataSet)       mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果切分的数据集很小则退出直接创造叶节点        print 'back from here 3 ..'        return None, leafType(dataSet)    return bestIndex,bestValue  # 如果所有的提前终止条件都不满足,就返回切分特征和特征值# 找到数据的最佳二元切分方式def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): # ops是一个包含树构建所需的参数元组    # 把数据集分成两部分,如果满足停止条件返回None和某类模型的值    # 满足停止条件:feat是None,val是某类模型的值    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)    print 'feat, val :',feat, 'and',val    if feat == None:        print 'back creatTree..'        return val # 回归树:模型是常数,模型树:模型是线性方程     retTree = {}     retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataSet(dataSet, feat, val) # 不满足停止条件时,lSet, rSet是两个数据集    retTree['left'] = createTree(lSet, leafType, errType, ops) # 递归调用createTree()函数    retTree['right'] = createTree(rSet, leafType, errType, ops)    return retTree  # 模型树的叶节点生成函数def linearSolve(dataSet):       m,n = shape(dataSet)    X = mat(ones((m,n))); Y = mat(ones((m,1)))      X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1] # 将数据集格式化成目标变量和自变量     xTx = X.T*X    if linalg.det(xTx) == 0.0: # 这个矩阵是奇异的,不能求逆        raise NameError('This matrix is singular, cannot do inverse,\n\        try increasing the second value of ops')    ws = xTx.I * (X.T * Y) # 线性回归的系数    return ws,X,Y# 生成叶节点模型,返回回归系数def modelLeaf(dataSet):      ws,X,Y = linearSolve(dataSet)    return ws # 在给定数据集上计算误差,返回平方误差def modelErr(dataSet):    ws,X,Y = linearSolve(dataSet)    yHat = X * ws    return sum(power(Y - yHat,2)) def isTree(obj):    return (type(obj).__name__=='dict')# 对回归树节点进行预测def regTreeEval(model, inDat):    return float(model) # 返回树预测的值# 对模型树节点预测def modelTreeEval(model, inDat):    n = shape(inDat)[1]  # 格式化处理    X = mat(ones((1,n+1))) # 在原数据矩阵上增加第0列    X[:,1:n+1]=inDat    return float(X*model)# 自顶向下的遍历整棵树,直到命中叶节点为止def treeForeCast(tree, inData, modelEval=regTreeEval):    if not isTree(tree): # 判断是否是树字典,如果不是就返回        return modelEval(tree, inData)    if inData[tree['spInd']] > tree['spVal']:        if isTree(tree['left']):             return treeForeCast(tree['left'], inData, modelEval)        else: return modelEval(tree['left'], inData)    else:        if isTree(tree['right']):             return treeForeCast(tree['right'], inData, modelEval)        else: return modelEval(tree['right'], inData)def createForeCast(tree, testData, modelEval=regTreeEval):    m=len(testData)    yHat = mat(zeros((m,1)))    for i in range(m):        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)    return yHat# 利用训练数据构造回归树trainMat=mat(loadDataSet('bikeSpeedVsIq_train.txt'))testMat=mat(loadDataSet('bikeSpeedVsIq_test.txt'))myTree=createTree(trainMat,ops=(1,20))  # 利用训练数据构造回归树yHat=createForeCast(myTree, testMat[:,0])coefficient_regtree=corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]# 利用训练数据构造模型树myTree=createTree(trainMat,modelLeaf,modelErr,ops=(1,20))  # 利用训练数据构造回归树yHat=createForeCast(myTree, testMat[:,0],modelTreeEval)coefficient_modeltree=corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]print 'regtree coefficient:',coefficient_regtreeprint 'modeltree coefficient:',coefficient_modeltree

运行结果:

...regtree coefficient: 0.964085231822modeltree coefficient: 0.976041219138

我们知道,R2的值越接近1.0越好,所以从上面的结果可以看出模型树的结果比回归树的要好,而线性回归的效果还不如回归树。

原创粉丝点击