回归树-----生成回归树

来源:互联网 发布:java常见面试题及答案 编辑:程序博客网 时间:2024/06/06 16:06

当数据拥有众多属性并且属性间关系复杂时,前面所讲的回归算法就显得太难了。今天我们就讨论一种树型的回归算法。前面讲过一个树,叫做决策树,构建决策树时需要利用信息增益来计算出最佳的分类特征然后不断的从剩余的特征中找出最佳的分类特征进行分类,这种方法叫做ID3.今天我们构建树所用的切分数据的方法有所不同,用的是二分法。其主体思想也是找到最佳的分类特征:ex00.txt (提取码:b416)

from numpy import *def loadDataSet(fileName):         dataMat = []                    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')        fltLine = map(float,curLine)        dataMat.append(fltLine)    return dataMatdef binSplitDataSet(dataSet, feature, value):    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]    return mat0,mat1def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)     if feat == None: return val     retTree = {}     retTree['spInd'] = feat     retTree['spVal'] = val     lSet, rSet = binSplitDataSet(dataSet, feat, val)     retTree['left'] = createTree(lSet, leafType, errType, ops)     retTree['right'] = createTree(rSet, leafType, errType, ops)     return retTree 
上面代码的第一个函数依旧是文件读写。第二个函数有三个输入参数,数据集合,待切分的特征和该特征的值,通过给定的特征以及特征值将数据集合切分到两个子集中,第三个函数用于建树,在这个函数有四个输入参数,数据集、建立叶子节点函数、误差计算函数和一个构建树所需的元组参数首先会用到切分函数对数据集进行切分然后判断是否满

足停止条件接下来递归调用函数来构建树。下面我们着重讲解数据切分:

def regLeaf(dataSet):    return mean(dataSet[:,-1])def regErr(dataSet):    return var(dataSet[:,-1]) * shape(dataSet)[0]def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):     tolS = ops[0]; tolN = ops[1]     if len(set(dataSet[:,-1].T.tolist()[0])) == 1:         return None, leafType(dataSet)     m,n = shape(dataSet)     S = errType(dataSet)     bestS = inf; bestIndex = 0; bestValue = 0     for featIndex in range(n-1):         for splitVal in set(dataSet[:,featIndex]):             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue             newS = errType(mat0) + errType(mat1)             if newS < bestS:                  bestIndex = featIndex                 bestValue = splitVal                 bestS = newS     if (S - bestS) < tolS:          return None, leafType(dataSet)     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):         return None, leafType(dataSet)     return bestIndex,bestValue
第一个函数负责生成叶子节点,第二个函数计算出错误率。这个算法的关键部分是第三个函数,数据切分函数,这个函数有四个输入参数,数据集、叶子节点生成函数、错误率计算函数和元组参数,这个元组的第一个数是容许的误差下降值,第二个参数是切分的最少样本数,这两个值用户可以自己修改。接下来判断所有值是否相等,如果相等就退出接下来看两个for循环,通过遍历整个数据集找出最佳的特征和特征值,然后判断误差减小是否小于阀值,如果小于阀值就退出。另一种退出条件是切分数据集是否过小,如果很小同样退出。

0 0
原创粉丝点击