The error measure, the leaf nodes, and prediction are closely related. In the simplest setup, the error is the total squared deviation of the y values, each leaf node stores the mean y value of the samples that fall into it, and prediction walks the tree and returns the mean stored at the matching leaf.
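As a quick numeric illustration of this mean-leaf setup (plain NumPy; the variable names are illustrative):

```python
import numpy as np

# For a mean-valued leaf: the leaf stores the mean of y, and the node's
# error is the total squared error, i.e. variance times sample count.
y = np.array([1.0, 2.0, 3.0, 4.0])
leaf_value = y.mean()            # value returned at prediction time
node_error = y.var() * len(y)    # total squared error at this node
print(leaf_value, node_error)    # 2.5 5.0
```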
When choosing the best feature to split on, two additional parameters are usually involved: the minimum allowed error reduction and the minimum number of samples per split. The error-reduction bound requires that a split decrease the error by at least that amount; the minimum-sample bound requires that each subtree produced by the split contain at least that many samples. Both strategies exist to avoid overfitting.
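The two bounds can be sketched as a small helper (a sketch only; the names tol_s and tol_n are illustrative, matching the ops tuple used in the implementation below):

```python
def should_split(total_err, split_err, n_left, n_right, tol_s=1.0, tol_n=4):
    # Pre-pruning checks: the split must reduce the error by at least
    # tol_s, and each child must keep at least tol_n samples.
    if total_err - split_err < tol_s:
        return False
    if n_left < tol_n or n_right < tol_n:
        return False
    return True

print(should_split(10.0, 8.0, 5, 6))   # True: big reduction, enough samples
print(should_split(10.0, 9.5, 5, 6))   # False: reduction below tol_s
print(should_split(10.0, 8.0, 3, 6))   # False: left child too small
```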
Setting these parameters while choosing the best split is a form of pre-pruning; pruning the tree after it has been built is post-pruning.
Post-pruning proceeds as follows:
1. If either child is itself a subtree, recursively prune within that subtree.
2. Compute the error after merging the two current leaf nodes.
3. Compute the error without merging.
4. Compare the two: if merging lowers the error, merge the leaves.
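The merge test in the last three steps can be sketched as follows (illustrative names; y_left and y_right stand for the test-set y values routed to each leaf):

```python
import numpy as np

def merge_errors(y_left, y_right, left_val, right_val):
    # Error if the two leaves are kept separate...
    no_merge = np.sum((y_left - left_val) ** 2) + np.sum((y_right - right_val) ** 2)
    # ...versus the error if they are merged into their mean.
    merged_val = (left_val + right_val) / 2.0
    y_all = np.concatenate([y_left, y_right])
    merged = np.sum((y_all - merged_val) ** 2)
    return no_merge, merged

no_merge, merged = merge_errors(np.array([1.0, 1.2]), np.array([1.1, 0.9]), 2.0, 0.2)
print(merged < no_merge)   # True: merging lowers the test error here
```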
5 Model Trees
As noted earlier, the error measure, the leaf nodes, and prediction go together. If each leaf node is built from a model instead of a constant, the result is a model tree. Here the linear regression model from earlier can be used to build the leaf nodes: the error becomes the squared error of the linear fit, and prediction evaluates the parameters fitted at the matching leaf on the input.
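A minimal sketch of a model-tree leaf in plain NumPy (illustrative only; the implementation below solves the normal equations instead of calling lstsq):

```python
import numpy as np

# A model-tree leaf fits a linear model to its samples: the leaf stores
# the weights, the error is the residual sum of squares, and prediction
# evaluates the fitted model on new inputs.
X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])  # bias column + feature
y = np.array([1.0, 3.0, 5.0])                        # exactly y = 1 + 2x
ws, *_ = np.linalg.lstsq(X, y, rcond=None)           # leaf value: the weights
rss = np.sum((X @ ws - y) ** 2)                      # leaf error
print(ws, rss)
```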
6 Implementation
Here createTree builds the tree; chooseBestSplit selects the best feature to split on, with the ops parameter supplying the two bounds; prune performs the post-pruning.
regErr, regLeaf, and regTreeEval are the mean-based error, leaf, and prediction functions; modelErr, modelLeaf, and modelTreeEval (plus linearSolve) are their linear-regression counterparts.
Dataset link: http://pan.baidu.com/share/link?shareid=3744521160&uk=973467359 password: 9ivd
from numpy import *

def loadDataSet(filename):
    # Read a tab-separated file into a list of float rows.
    dataMat = []
    fr = open(filename)
    for line in fr.readlines():
        curLine = line.strip('\n').split('\t')
        fltLine = list(map(float, curLine))  # list(...) needed on Python 3
        dataMat.append(fltLine)
    fr.close()
    return dataMat
def regLeaf(dataSet):
    # Leaf value: mean of y in this node.
    return mean(dataSet[:, -1])

def regErr(dataSet):
    # Node error: total squared error (variance * sample count).
    return var(dataSet[:, -1]) * shape(dataSet)[0]

def regTreeEval(model, inDat):
    # A regression-tree leaf is just a constant.
    return float(model)
def linearSolve(dataSet):
    # Fit Y = X * ws by ordinary least squares; the first column of X is the bias.
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, '
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)  # normal equations: ws = (X^T X)^-1 X^T Y
    return ws, X, Y
def modelLeaf(dataSet):
    # Leaf value: the fitted linear-regression weights.
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    # Node error: residual sum of squares of the linear fit.
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

def modelTreeEval(model, inDat):
    # Evaluate a model-tree leaf: prepend the bias term, then X * ws.
    n = shape(inDat)[1]
    X = mat(ones((1, n+1)))
    X[:, 1:n+1] = inDat
    return float(X * model)
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]  # minimum error reduction allowed for a split
    tolN = ops[1]  # minimum number of samples in each child
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)  # all y values identical
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n-1):
        # .T.tolist()[0] turns the matrix column into a plain list of values
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)  # not enough error reduction
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        print("Not enough samples")
        return None, leafType(dataSet)
    return bestIndex, bestValue
def binSplitDataSet(dataSet, feature, value):
    # Binary split: rows with feature > value go left, the rest right.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        return val  # stopping condition hit: return a leaf
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    # Collapse a subtree to the mean of its leaves.
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
def prune(tree, testData):
    if shape(testData)[0] == 0:
        return getMean(tree)  # no test data reaches this subtree
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['right']) and not isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("Merging")
            return treeMean
        else:
            return tree
    else:
        return tree
def treeForeCast(tree, inData, modelEval=regTreeEval):
    # Walk the tree for one input row and evaluate the matching leaf.
    if not isTree(tree):
        return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat
- ''
-
-
-
-
-
-
-
-
trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))
myregTree = createTree(trainMat, ops=(1, 20))
mymodTree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
yregHat = createForeCast(myregTree, testMat[:, 0])
ymodHat = createForeCast(mymodTree, testMat[:, 0], modelTreeEval)
# Correlation between predictions and true y as a goodness-of-fit measure
regCo = corrcoef(yregHat, testMat[:, 1], rowvar=0)[0, 1]
modCo = corrcoef(ymodHat, testMat[:, 1], rowvar=0)[0, 1]
print("reg", regCo)
print("model", modCo)