CART分类回归树-(python3)

来源:互联网 发布:sql注入语句怎么使用 编辑:程序博客网 时间:2024/05/16 05:54

一、树回归

1、简介

假设X与Y分别是输入和输出向量,并且Y是连续变量,给定训练数据集
考虑如何生成回归树。
一个回归树对应着输入空间(即特征空间)的一个划分以及在划分的但单元上的输出值。假设已将输入空间划分为M个单元
  
,并且在每个单元
  
上有一个固定的输出值
  
,于是回归树模型可表示为(简单来说就是把数据集划分为多份数据,且每份数据集里面的输出一致)
对固定输入变量
  
可以找到最优切分点
 (找到最小的平方误差的特征量)
遍历所有输入变量,找到最优切分变量
  
,构成一个对
  
,依次将输入空间划分为两个区域。接着,每个对每个区域重复上述划分过程,直到满足停止条件为止。这样就生成一棵回归树,这样的回归树通常被称为最小二乘树。[1] 

树回归的大致过程

(1)载入数据

#coding:utf-8from numpy import *def loadDataset(filename):    dataMat = []    fr = open(filename)    for line in fr.readlines():        curline = line.strip().split('\t')        fltline = map(float, curline)        # print(list(fltline))        dataMat.append(list(fltline))    return dataMat
(2)binSplitDataset函数切分数据集
def binSplitDataset(dataSet, feature, value):#以这一列的每个值为界限,大于它和小于它的值,返回的是以这个特征值为界限分割的数据集    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]#返回索引,切割数据集    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]    return mat0, mat1
(3)函数计算平方误差,均值
def regLeaf(dataSet):           #建立叶节点函数,value为所有y的均值    return mean(dataSet[:,-1])def regErr(dataset):    return var(dataset[:, -1]) * shape(dataset)[0]#y的方差×y的数量=平方误差

(4)选择最好的切割方式
def chooseBestsplit(dataset, leafType=regLeaf, errtype = regErr,ops=(1, 4)):#找到最好的分割叶子节点    tolS = ops[0]##允许的误差下降值    tolN = ops[1] #切分的最小样本数    #判断是否可以分开二叉树    # print(len(set(dataset[:, -1].T.tolist()[0])))#不是一下子分开,然后就是先分割整个数据集,然后分割左边,然后右边    if len(set(dataset[:, -1].T.tolist()[0])) == 1:  # #如果剩余特征值的数量等于1,不需要再切分直接返回,(退出条件1)        return None, leafType(dataset)    m, n = shape(dataset)#行列数    S = errtype(dataset)#计算平方差    bestS = inf    bestIndex = 0    bestValue = 0    for featIndex in range(n - 1):#特征索引        for splitVal in set((dataset[:, featIndex].T.A.tolist())[0]):  #每一列的每个值            mat0, mat1 = binSplitDataset(dataset, featIndex, splitVal)#整个数据集,第几列,那一列的每个值            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue#样本数最小限制            # print(errtype(mat0))            newS = errtype(mat0) + errtype(mat1)#计算平方误差            if newS < bestS:                bestIndex = featIndex                bestValue = splitVal                bestS = newS    if (S - bestS) < tolS:#如果切分后误差效果下降不大,则取消切分,直接创建叶结点        return None, leafType(dataset)    mat0, mat1 = binSplitDataset(dataset, bestIndex, bestValue)  # 按照保存的最佳分割来划分集合    # #判断切分后子集大小,小于最小允许样本数停止切分3    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):        return None, leafType(dataset)    # 返回最佳二元切割的bestIndex和bestValue    return bestIndex, bestValue#返回特征编号和用于切分的特征值
(5)构造回归树
def isTree(obj):    return (type(obj).__name__=='dict') #判断为字典类型返回true#返回树的所有分支的和的平均值def getMean(tree):    if isTree(tree['right']):#找到        tree['right'] = getMean(tree['right'])#得到的是有右子树的平均值    if isTree(tree['left']):        tree['left'] = getMean(tree['left'])#左子树的平均值    return (tree['left']+tree['right'])/2.0#返回的就是整个数对于输入的特征的所判断的均值
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering    feat, val = chooseBestsplit(dataSet, leafType, errType, ops)    #采用最佳分割,将数据集分成两个部分    if feat == None: return val     #递归结束条件    retTree = {}                    #建立返回的字典    retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataset(dataSet, feat, val)    #得到左子树集合和右子树集合    retTree['left'] = createTree(lSet, leafType, errType, ops)      #递归左子树    retTree['right'] = createTree(rSet, leafType, errType, ops)     #递归右子树    return retTree
树的剪枝

#树的后剪枝,def prune(tree, testData):#待剪枝的树和剪枝所需的测试数据    if shape(testData)[0] == 0:# 确认数据集非空        return getMean(tree)    #假设发生过拟合,采用测试数据对树进行剪枝    if (isTree(tree['right']) or isTree(tree['left'])): #左右子树非空        lSet, rSet = binSplitDataset(testData, tree['spInd'], tree['spVal'])#按照索引,和值分割数据集    if isTree(tree['left']):        tree['left'] = prune(tree['left'], lSet)    if isTree(tree['right']):        tree['right'] = prune(tree['right'], rSet)    #剪枝后判断是否还是有子树    if not isTree(tree['left']) and not isTree(tree['right']):#只要有一个空        lSet, rSet = binSplitDataset(testData, tree['spInd'], tree['spVal'])        #判断是否融合        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \                       sum(power(rSet[:, -1] - tree['right'], 2))#未熔合的方差        treeMean = (tree['left'] + tree['right']) / 2.0#平均值        errorMerge = sum(power(testData[:, -1] - treeMean, 2))#查看融合后的方差        #如果合并后误差变小,融合,将两个叶子的均值作为节点        if errorMerge < errorNoMerge:            print("merging")            return treeMean        else:            return tree    else:        return tree

模型树
def linearSolve(dataSet):   #将数据集格式化为X Y    m,n = shape(dataSet)    X = mat(ones((m, n)))    Y = mat(ones((m, 1)))    X[:, 1:n] = dataSet[:, 0:n-1]#把x矩阵第一列全设置为1    Y = dataSet[:, -1]    xTx = X.T*X    # print(xTx)    if linalg.det(xTx) == 0.0: #X Y用于简单线性回归,需要判断矩阵可逆        raise NameError('This matrix is singular, cannot do inverse,\n\        try increasing the second value of ops')    ws = xTx.I * (X.T * Y)#正规方程    # print(ws)    return ws, X, Ydef modelLeaf(dataSet):#不需要切分时生成模型树叶节点    ws,X,Y = linearSolve(dataSet)    return ws #返回回归系数def modelErr(dataSet):#用来计算误差找到最佳切分    ws,X,Y = linearSolve(dataSet)    yHat = X * ws    # print(yHat)    return sum(power(Y - yHat, 2))










原创粉丝点击