树回归

来源:互联网 发布:windows snmp v2c 编辑:程序博客网 时间:2024/05/21 06:19

1.简单介绍

        线性回归方法可以有效的拟合所有样本点(局部加权线性回归除外)。当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法一个是困难一个是笨拙。此外,实际中很多问题为非线性的,例如常见到的分段函数,不可能用全局线性模型来进行拟合。

树回归将数据集切分成多份易建模的数据,然后利用线性回归进行建模和拟合。这里介绍较为经典的树回归CART(classification and regression trees,分类回归树)算法。

2.分类回归树基本流程

    构建树:

           1.找到[最佳待切分特征]

            2.若不能再切分,则将该节点存为[叶子节点]并返回

            3.按照最佳待切分特征将数据集切分成左右子树(这里为了方便,假设大于特征值则为左,小于则归为右)

            4.对左子树进行[构建树]

            5.对右子树进行[构建树]

   最佳待切分特征:

           1.遍历特征

               1.1遍历特征所有特征值

                    1.1.1计算按该特征值进行数据集切分的[误差]

           2.选择误差最小的特征及其相应值作为最佳待切分特征并返回

   基于回归树的预测:

           1.判断当前回归树是否为叶子节点,如果是则[预测],如果不是则执行2

            2.将测试数据相应特征上的特征值与当前回归树进行比较,如果测试数据特征值大,则判别当前回归树的左子树是否为叶子节点,如果不是叶子节点则进行[基于回归树的预测],如果是叶子节点,则[预测];反之,判别当前回归树的右子树是否为叶子节点,如果不是叶子节点则进行[基于回归树的预测],如果是叶子节点,则[预测]

3.分类回归树的实践说明

  

        误差、叶子节点和预测三者有相关的关联关系,一种相对简单的是误差采用的是y值均方差,叶子节点相应的建立为该节点下所有样本的y值平均值,预测的时候根据判断返回该叶子节点下y值平均值即可。

        在进行最佳待切分特征选取的时候,一般还有两个参数,一个是允许的误差下降值,一个是切分最小样本数。对于允许误差下降值,在实际过程中,需要在分割之后其误差减少应该至少大于该bound;对于切分最小样本数,也就是说切分后的子树中包含的样本数应该多于该bound。其实这两种策略都是为了避免过拟合。

4树剪枝

       

        通过在最佳待切分特征选取时进行参数设定来避免过拟合,这其实是一种预剪枝的行为;而在回归树建立后,再进行剪枝,则是一种后剪枝的行为。

        后剪枝的过程如下:

               如果存在任一子集是一棵树,则在该子集中递归剪枝

               计算当前两个叶子节点合并后的误差

               计算不合并的误差

              比较合并前后误差,如果合并后的误差降低,则对叶子节点进行合并

5模型树

        之前讲到误差、叶子节点和预测三者具备关联关系,当建立叶子节点是基于模型的,则构建了相应的模型树。这里可以使用之前的线性回归模型,建立相应的叶子节点。这样误差计算采用的将是线性回归中的误差,而预测则是基于该叶子节点拟合其样本后的参数。

6编程实现

         这里createTree负责进行树的构建;chooseBestSplit函数负责进行最佳带切特征的选取,而ops参数则是进行了两个bound的设定;prune进行了相关后剪枝。
         这里regErr、regLeaf、regTreeEval是基于简单均值计算的误差、叶子节点和预测;而modelErr、modelLeaf和modelTreeEval(+linearSolve)则是基于线性回顾模型的误差、叶子节点和预测。
         数据集链接:http://pan.baidu.com/share/link?shareid=3744521160&uk=973467359 密码:9ivd
[python] view plaincopyprint?
  1. from numpy import *  
  2. def loadDataSet(filename):  
  3.     dataMat = []  
  4.     fr = open(filename)  
  5.     for line in fr.readlines():  
  6.         curLine = line.strip('\n').split('\t')  
  7.         fltLine = map(float, curLine)  
  8.         dataMat.append(fltLine)  
  9.     fr.close()  
  10.     return dataMat  
  11. def regLeaf(dataSet):  
  12.     return mean(dataSet[:,-1])  
  13. def regErr(dataSet):  
  14.     return var(dataSet[:,-1])*shape(dataSet)[0]  
  15. def regTreeEval(model, inDat):  
  16.     return float(model)  
  17. def linearSolve(dataSet):  
  18.     m,n=shape(dataSet)  
  19.     X = mat(ones((m,n)))  
  20.     Y = mat(ones((m,1)))  
  21.     X[:,1:n]=dataSet[:,0:n-1]  
  22.     Y=dataSet[:,-1]  
  23.     xTx = X.T*X  
  24.     if linalg.det(xTx)==0.0:  
  25.         raise NameError('This matrix is singular, cannot do inverse, \  
  26.                try increasing the second value of ops')  
  27.     ws = xTx.T*(X.T*Y)  
  28.     return ws, X, Y  
  29. def modelLeaf(dataSet):  
  30.     ws, X, Y = linearSolve(dataSet)  
  31.     return ws  
  32. def modelErr(dataSet):  
  33.     ws,X,Y = linearSolve(dataSet)  
  34.     yHat = X*ws  
  35.     return sum(power(Y-yHat,2))  
  36. def modelTreeEval(model, inDat):  
  37.     n=shape(inDat)[1]  
  38.     X = mat(ones((1,n+1)))  
  39.     X[:,1:n+1]=inDat  
  40.     return float(X*model)  
  41. def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):  
  42.     tolS = ops[0]  
  43.     tolN = ops[1]  
  44.     if len(set(dataSet[:,-1].T.tolist()[0])) == 1:  
  45.         return None, leafType(dataSet)  
  46.     m,n=shape(dataSet)  
  47.     S = errType(dataSet)  
  48.     bestS = inf  
  49.     bestIndex = 0  
  50.     bestValue = 0  
  51.     for featIndex in range(n-1):  
  52.         for splitVal in set(dataSet[:,featIndex]):  
  53.             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)  
  54.             if(shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):  
  55.                 continue  
  56.             newS = errType(mat0)+errType(mat1)  
  57.             if newS < bestS:  
  58.                 bestIndex = featIndex  
  59.                 bestValue = splitVal  
  60.                 bestS = newS  
  61.     if (S-bestS)<tolS:  
  62.         return None, leafType(dataSet)  
  63.     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)  
  64.     if(shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):  
  65.         print "Not enough nums"  
  66.         return None, leafType(dataSet)  
  67.     return bestIndex, bestValue  
  68. def binSplitDataSet(dataSet, feature, value):  
  69.     mat0 = dataSet[nonzero(dataSet[:, feature]>value)[0],:][0]  
  70.     mat1 = dataSet[nonzero(dataSet[:, feature]<=value)[0],:][0]  
  71.     return mat0, mat1  
  72. def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):  
  73.     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  
  74.     if feat == None:  
  75.         return val  
  76.     retTree={}  
  77.     retTree['spInd'] = feat  
  78.     retTree['spVal'] = val  
  79.     lSet, rSet = binSplitDataSet(dataSet, feat, val)  
  80.     retTree['left']=createTree(lSet, leafType, errType, ops)  
  81.     retTree['right']=createTree(rSet, leafType, errType, ops)  
  82.     return retTree  
  83. def isTree(obj):  
  84.     return (type(obj).__name__=='dict')  
  85. def getMean(tree):  
  86.     if isTree(tree['right']):  
  87.         tree['right'] = getMean(tree['right'])  
  88.     if isTree(tree['left']):  
  89.         tree['left'] = getMean(tree['left'])  
  90.     return (tree['left']+tree['right'])/2.0  
  91. def prune(tree, testData):  
  92.     if shape(testData)[0] == 0:  
  93.         return getMean(tree)  
  94.     if(isTree(tree['right']) or isTree(tree['left'])):  
  95.         lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])  
  96.     if isTree(tree['left']):  
  97.         tree['left']=prune(tree['left'],lSet)  
  98.     if isTree(tree['right']):  
  99.         tree['right']=prune(tree['right'],rSet)  
  100.     if not isTree(tree['right']) and not isTree(tree['left']):  
  101.         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])  
  102.         errorNoMerge = sum(power(lSet[:,-1]-tree['left'],2))+\  
  103.                        sum(power(rSet[:,-1]-tree['right'],2))  
  104.         treeMean = (tree['left']+tree['right'])/2.0  
  105.         errorMerge = sum(power(testData[:,-1]-treeMean,2))  
  106.         if errorMerge < errorNoMerge:  
  107.             print "Merging"  
  108.             return treeMean  
  109.         else:  
  110.             return tree  
  111.     else:  
  112.         return tree  
  113. def treeForeCast(tree, inData, modelEval=regTreeEval):  
  114.     if not isTree(tree):  
  115.         return modelEval(tree, inData)  
  116.     if inData[tree['spInd']]>tree['spVal']:  
  117.         if isTree(tree['left']):  
  118.             return treeForeCast(tree['left'], inData, modelEval)  
  119.         else:  
  120.             return modelEval(tree['left'],inData)  
  121.     else:  
  122.         if isTree(tree['right']):  
  123.             return treeForeCast(tree['right'], inData, modelEval)  
  124.         else:  
  125.             return modelEval(tree['right'], inData)  
  126. def createForeCast(tree, testData, modelEval=regTreeEval):  
  127.     m=len(testData)  
  128.     yHat = mat(zeros((m,1)))  
  129.     for i in range(m):  
  130.         yHat[i,0]=treeForeCast(tree, mat(testData[i]), modelEval)  
  131.     return yHat  
  132. ''''' 
  133. myData2 = loadDataSet(r"ex2.txt") 
  134. myMat2 = mat(myData2) 
  135. tree2 = createTree(myMat2, ops=(0,1)) 
  136. print tree2 
  137. myData2Test = loadDataSet(r"ex2test.txt") 
  138. myMat2Test = mat(myData2Test) 
  139. print prune(tree2, myMat2Test) 
  140. '''  
  141. trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))  
  142. testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))  
  143. myregTree=createTree(trainMat, ops=(1,20))  
  144. mymodTree=createTree(trainMat, modelLeaf, modelErr, (1,20))  
  145. yregHat=createForeCast(myregTree, testMat[:,0])  
  146. ymodHat=createForeCast(mymodTree, testMat[:,0], modelTreeEval)  
  147. regCo = corrcoef(yregHat, testMat[:,1], rowvar=0)[0,1]  
  148. modCo = corrcoef(ymodHat, testMat[:,1], rowvar=0)[0,1]  
  149. print "reg", regCo  
  150. print "model", modCo  
0 0
原创粉丝点击