CART树回归

来源:互联网 发布:淘宝开网店流程步骤 编辑:程序博客网 时间:2024/05/19 20:44

说明:本博客是学习《python机器学习算法》赵志勇著的学习笔记,其图片截取也来源本书。

基于树的回归算法是一类基于局部的回归算法,通过将数据集切分成多份,在每一份数据中单独建模。与局部加权线性回归不同的是,基于树回归的算法是一种基于参数学习的算法,利用训练数据训练完模型后,参数一定确定,无需再改变。

分类回归树(Classification And Regression Tree,CART)算法是使用比较多的一种树模型,CART算法既可以解决分类问题也可以解决回归问题。前面的博客随机森林中有介绍CART算法处理分类的问题,在这次的博客中将涉及到CART的回归问题。CART树回归属于一种局部的回归算法,通过将全局的数据集划分成多份容易建模的数据集,这样在每一个局部的数据集上进行局部的回归建模。

CART算法中的树采用一种二分递归分割技术,即将当前的样本集分为左子树和右子树两个样本集,使得生成的每个非子叶节点都有两个分支。因此,CART算法生成的决策树是非典型的二叉树。

利用CART算法处理回归问题的主要步骤:1.CART会归树的生成;2.CART回归树的剪枝。

1、CART回归树的生成

1.1、CART回归树的划分

在CART分类算法中,利用Gini指数作为树的指标,通过样本中的特征,对样本进行划分,直到所有的叶节点的所有样本都为同样类别为止。但在CART回归树中,样本的标是一系列的连续值的集合,不能再使用Gini指数作为划分树的指标。同时,我们也注意到Gini指数是衡量数据的混乱程度的,对于连续的数据,当数据分布比较分散时,各个数据与平均数的差的平方和较大,方差就越大;反之,当数据分布比较集中时,各个数据与平均数的差的平方和较小,方差就越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的平方和作为划分回归树的指标。
这里写图片描述

有了划分的标准,那如何划分数据呢?在CART中我们根据每一维特征中的每一个值,尝试将样本划分到树节点的左右子树中,如取样本特征中第j维特征中值x作为划分的值,如果一个样本在第j维初的值大于或者等于x,则将其划分到右子树中,具体划分过程如下图所示。

这里写图片描述
一般小于特征值在左子树,大于等于在右子树。

1.1、CART回归树的构建

CART分类树的构建过程如下所示:

1、对于当前序训练数据集,遍历所有属性及其所有可能的切分点,寻找最佳切分属性及其最佳切分点,使得切分之后的基尼指数最小,利用该最佳属性及其最佳切分点将数据集划分为两个子数据集,分别对应的结果就是左右两子树。

2、对第一步中生成的两个数据集递归调用第一步,直至满足条件为止。
3、生成CART决策树。

2、CART回归树剪枝

对CART进行剪枝的目的是防止CART出现过拟合。在剪枝中主要分为:前剪枝和后剪枝。

2.1、前剪枝

前剪枝是指在生成CART回归树可以通过设置最小的过程中对树的深度进行控制,防止生成过多的叶子节点。比如每次子树的最小样本数量和最小误差率。来控制是否进行更多的划分。

2.2、后剪枝

后剪枝是指将训练样本分为两个部分,一部分用来训练CART树模型,这部分数据称为训练数据,另一部分用来对生成的CART树模型进行剪枝,这一部分称为验证数据。

由上述过程可知,在后剪枝的过程中,通过验证生成好的CART树模型是否在验证数据集上发生了过拟合,如果出现了过拟合的现象,则合并一些叶子节点来达到对CART树模型的剪枝。

import numpy as np# import cPickle as pickleclass node:    '''树的节点的类    '''    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):        self.fea = fea  # 用于切分数据集的属性的列索引值        self.value = value  # 设置划分的值        self.results = results  # 存储叶节点的值        self.right = right  # 右子树        self.left = left  # 左子树def load_data(data_file):    '''导入训练数据    input:  data_file(string):保存训练数据的文件    output: data(list):训练数据    '''    data = []    f = open(data_file)    for line in f.readlines():        sample = []        lines = line.strip().split("\t")        for x in lines:            sample.append(float(x))  # 转换成float格式        data.append(sample)    f.close()    return datadef split_tree(data, fea, value):    '''根据特征fea中的值value将数据集data划分成左右子树    input:  data(list):训练样本            fea(float):需要划分的特征index            value(float):指定的划分的值    output: (set_1, set_2)(tuple):左右子树的聚合    '''    set_1 = []  # 右子树的集合    set_2 = []  # 左子树的集合    for x in data:        if x[fea] >= value:            set_1.append(x)        else:            set_2.append(x)    return (set_1, set_2)def leaf(dataSet):    '''计算叶节点的值    input:  dataSet(list):训练样本    output: np.mean(data[:, -1])(float):均值    '''    data = np.mat(dataSet)    return np.mean(data[:, -1])def err_cnt(dataSet):    '''回归树的划分指标    input:  dataSet(list):训练数据    output: m*s^2(float):总方差    '''    data = np.mat(dataSet)    return np.var(data[:, -1]) * np.shape(data)[0]def build_tree(data, min_sample, min_err):    '''构建树    input:  data(list):训练样本            min_sample(int):叶子节点中最少的样本数            min_err(float):最小的error    output: node:树的根结点    '''    # 构建决策树,函数返回该决策树的根节点    if len(data) <= min_sample:        return node(results=leaf(data))    # 1、初始化    best_err = err_cnt(data)    bestCriteria = None  # 存储最佳切分属性以及最佳切分点    bestSets = None  # 存储切分后的两个数据集    # 2、开始构建CART回归树    feature_num = len(data[0]) - 1    for fea in range(0, feature_num):        feature_values = {}        for sample in data:            feature_values[sample[fea]] = 1        for value in feature_values.keys():            # 2.1、尝试划分            (set_1, set_2) = split_tree(data, fea, value)            if len(set_1) < 2 or len(set_2) < 2:                continue            # 2.2、计算划分后的error值            now_err = err_cnt(set_1) + err_cnt(set_2)            # 2.3、更新最优划分            if now_err < best_err and len(set_1) > 0 and len(set_2) > 0:                best_err = now_err                bestCriteria = (fea, value)                bestSets = (set_1, set_2)    # 3、判断划分是否结束    if best_err > min_err:        right = build_tree(bestSets[0], min_sample, min_err)        left = build_tree(bestSets[1], min_sample, min_err)        return node(fea=bestCriteria[0], value=bestCriteria[1], \                    right=right, left=left)    else:        return node(results=leaf(data))  # 返回当前的类别标签作为最终的类别标签def predict(sample, tree):    '''对每一个样本sample进行预测    input:  sample(list):样本            tree:训练好的CART回归树模型    output: results(float):预测值    '''    # 1、只是树根    if tree.results != None:        return tree.results    else:    # 2、有左右子树        val_sample = sample[tree.fea]  # fea处的值        branch = None        # 2.1、选择右子树        if val_sample >= tree.value:            branch = tree.right        # 2.2、选择左子树        else:            branch = tree.left        return predict(sample, branch)def cal_error(data, tree):    ''' 评估CART回归树模型    input:  data(list):            tree:训练好的CART回归树模型    output: err/m(float):均方误差    '''    m = len(data)  # 样本的个数       n = len(data[0]) - 1  # 样本中特征的个数    err = 0.0    for i in range(m):        tmp = []        for j in range(n):            tmp.append(data[i][j])        pre = predict(tmp, tree)  # 对样本计算其预测值        # 计算残差        err += (data[i][-1] - pre) * (data[i][-1] - pre)    return err / m# def save_model(regression_tree, result_file):#     '''将训练好的CART回归树模型保存到本地#     input:  regression_tree:回归树模型#             result_file(string):文件名#     '''#     with open(result_file, 'w') as f:#         pickle.dump(regression_tree, f)if __name__ == "__main__":    # 1、导入训练数据    print("----------- 1、load data -------------")    data = load_data("C:\\Python-Machine-Learning-Algorithm-master\\Chapter_9 CART\\sine.txt")    # 2、构建CART树    print("----------- 2、build CART ------------")    regression_tree = build_tree(data, 30, 0.3)    # 3、评估CART树    print("----------- 3、cal err -------------")     err = cal_error(data, regression_tree)    print("\t--------- err : ", err)#     # 4、保存最终的CART模型#     print "----------- 4、save result -----------"  #     save_model(regression_tree, "regression_tree")
----------- 1、load data ------------------------ 2、build CART ----------------------- 3、cal err -------------    --------- err :  0.017472194888