机器学习笔记十二:分类与回归树CART

来源:互联网 发布:淘宝api推广 编辑:程序博客网 时间:2024/06/05 05:38

一.基本概念

CART是classification and regression tree的缩写,即分类与回归树.一般都是把CART放在决策树这部分学习的.这里单独讲一下CART,因为他和一些复杂的算法也有关系.
相比起学习CART,更重要的是学习CART里面的那种思想.

二.CART回归

Ⅰ.CART回归树概念

设我们有数据集D, XY分别代表输入和输出变量,其中Y是连续变量(回归模型),包含m个样本的数据集格式如下:
这里写图片描述
很重要的一个思想是:一个回归树本质代表着一个被划分的输入空间和这些划分单元上面各自的输出值.
详细一点来说,一个回归树能够将一个输出空间划分为K个部分这里写图片描述,并且在每个部分这里写图片描述上面有一个固定的输出值这里写图片描述,那么回归树模型可以表示为:
这里写图片描述
也就是说,只要模型正确的建立起来,只要知道输入向量属于哪个部分,就马上能够得到他的值了.(上式中的I是指示函数.)

要理解CART回归,首先需要知道什么是切分变量切分点.切分,顾名思义就是把输入空间划破的意思啦.

切分变量其实可以叫做切分特征.就是在某个特征上面进行选择.
切分点,就是一个值,这个值把一个集合分成两边.

切分特征和切分点共同决定了一个集合应该以怎么样的方式来切分.

这样的表述可能有一些难以理解.看一个例子就很容易懂了.比如下面有一组数据
这里写图片描述
(其中一行表示一个样本,一列表示一个特征.我就不啰嗦了,几乎约定俗成的事情)

当取第第0个特征(特征索引从0开始),当前特征值大于3的成为一部分,当前特征值小于等于3的成为另外一部分.就如下图所示.
在这里,切分特征就是0,切分点就是3.这两个就能能够决定一个集合的划分了.是不是很好懂?
这里写图片描述

同样的..取特征1,切分点为20把数据集切分为两部分.如下图,也是很好理解的.
这里写图片描述

那么理论一点来看的话,比如选择第j个特征和一个切分点s,就能够定义两个部分了.可以写成:
这里写图片描述

这个时候,你应该会想,这些特征与切分点的组合是不是可以有很多种呢?没错,每一种特征和切分点的组合都能够表示对于集合的一种划分,那么划分方式太多了.
所以,CART回归非常重要的就是寻找最优切分特征和最优切分点.
一提到最优的问题是不是就想到通过一个最小化一个优化函数来做了?其实这里也一样.
这里用的是平方误差:
这里写图片描述

更具体一点,这里采用遍历所有输入变量来找最优切分特征j和最优切分点s.即
这里写图片描述

通常这里认为一个区域上面的结果为该区域对应所有输出的均值.即
这里写图片描述

也就是说,对于上面找最优切分点的时候,这里写图片描述可以替换成为该区域上面的所有输出的均值.

找到最优切分(j,s)之后,切分就能够将集合切分成总损失最小的两个部分.对于切分出来的区域再重复递归这样的划分过程,直到满足条件为止.那么就生成了一棵回归树.

Ⅱ.算法细节

1.遍历输入空间,找到最优切分特征j与切分点s
这里写图片描述

2.对于得到的最优切分(j,s)划分区域,并且得到输出值
这里写图片描述

3.对两个子区域递归调用步骤1,2直到满足停止条件

4.输入空间划分为K个部分,并且在每个部分上面有一个固定的输出值,那么回归树模型可以表示为:
这里写图片描述

Ⅲ.实现

实现部分采用的数据集是机器学习实战中的数据集.代码则是按照自己的理解重新改写了一遍.

读取数据模块:data.py

import numpy as npdef loadData(filename):    dataSet=np.loadtxt(fname=filename,dtype=np.float32)    return dataSet

用numpy内置的读取txt文件的函数就行,方便快捷.这里就不多讲了.

CART核心模块:CART.py

import numpy as npimport matplotlib.pyplot as plt#split dataSet trough featureIndex and valuedef splitDataSet(dataset,featureIndex,value):    subDataSet0=dataset[dataset[:,featureIndex]<=value,:]    subDataSet1=dataset[dataset[:,featureIndex]>value,:]    return subDataSet0,subDataSet1#compute the regression Error in a data Setdef getError(dataSet):    error=np.var(dataSet[:,-1])*dataSet.shape[0]    return error#choose the best featureIndex and value in dataSetdef chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit):    rows,cols=np.shape(dataSet)    #error in dataSet    Error=getError(dataSet)    #init some important value we want get    bestError=np.inf    bestFeatureIndex=0    bestValue=0    #search process    #every feature index    for featureIndex in range(cols-1):        #every value in dataSet of specific index        for value in set(dataSet[:,featureIndex]):            subDataSet0,subDataSet1=splitDataSet(dataSet,featureIndex,value)            #print("sub0",subDataSet0.shape[0])            #print("sub1", subDataSet1.shape[0])          #  print(subDataSet0)            if (subDataSet0.shape[0]<leastNumOfSplit) or (subDataSet1.shape[0]<leastNumOfSplit):                continue            #compute error            tempError=getError(subDataSet0)+getError(subDataSet1)            #print("tempError:",tempError)            if tempError<bestError:                bestError=tempError                bestFeatureIndex=featureIndex                bestValue=value           # print("BestError:", bestError)           # print("BestIndex:", bestFeatureIndex)           # print("BestValue:", bestValue)    if Error-bestError<leastErrorDescent:        return None,np.mean(dataSet[:,-1])    mat0,mat1=splitDataSet(dataSet,bestFeatureIndex,bestValue)    if (mat0.shape[0]<leastNumOfSplit) or (mat1.shape[0]<leastNumOfSplit):        return None,np.mean(dataSet[:,-1])    return bestFeatureIndex,bestValue#build treedef buildTree(dataSet,leastErrorDescent=1,leastNumOfSplit=4):    bestFeatureIndex,bestValue=chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit)    #recursion termination    if bestFeatureIndex==None:        return bestValue    Tree={}    Tree["featureIndex"]=bestFeatureIndex    Tree["value"]=bestValue    #get subset    leftSet,rightSet=splitDataSet(dataSet,bestFeatureIndex,bestValue)    #recursive function    Tree["left"]=buildTree(leftSet,leastErrorDescent,leastNumOfSplit)    Tree["right"] = buildTree(rightSet, leastErrorDescent, leastNumOfSplit)    return Treedef isTree(tree):    return (type(tree).__name__=='dict')def predict(tree,x):    if x[tree["featureIndex"]]<tree["value"]:        if isTree(tree["left"]):            return predict(tree["left"],x)        else:            return tree["left"]    else:        if isTree(tree["right"]):            return predict(tree["right"],x)        else:            return tree["right"]

这里一个一个来讲这些函数.
splitDataSet(dataset,featureIndex,value)
在理论部分已经讲到,我们要划分数据集,只需要两个值,一个就是特征,另外就是指定的阈值.
这个函数的作用就是通过传入的特征和阈值,把数据集划分为两部分.理论部分例子的图就可以形象展示这个函数的作用.

getError(dataSet)
这个函数是用来得到误差的.说是误差,倒不如说是方差.因为理论部分已经给出了式子,其中的c是可以用平均值来替代的,也就是是,刚好是数据集上面的总的方差.

chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit)
顾名思义,就是找最好的划分罗.
leastErrorDescent这个参数表示最小的下降误差,也就是说要是在某一刻,误差的下降小于这个值,函数就会退出,leastNumOfSplit表示最小的划分数量.当要划分的集合元素小于这个阈值时候,被认为是没有什么划分的意义了,函数也不会再运行.
然后函数遍历数据集上面所有的特征,与特征上面的所有值,以找到最好的特征索引和划分点返回.

测试文件:run.py

import numpy as npimport dataimport CARTdataMat1=data.loadData("../data/ex00.txt")dataMat2=data.loadData("../data/ex0.txt")'''print(dataMat.shape)print(np.shape(dataMat))e=CART.getError(dataMat)print(e)print(CART.getError(mat0))print(CART.getError(mat1))mat0,mat1=CART.splitDataSet(dataMat,0,0.5)print(mat0)print(mat1)print(mat0.shape)'''#bestIndex,bestValue=CART.chooseBestSplit(dataMat)#print(bestIndex,bestValue)#tree1tree1=CART.buildTree(dataMat1)print(tree1)#tree2tree2=CART.buildTree(dataMat2)print(tree2)x=[1.0,0.559009]print(CART.predict(tree2,x))

用来测试CART回归的运行代码.

import numpy as npimport dataimport CARTdataMat1=data.loadData("../data/ex00.txt")dataMat2=data.loadData("../data/ex0.txt")'''print(dataMat.shape)print(np.shape(dataMat))e=CART.getError(dataMat)print(e)print(CART.getError(mat0))print(CART.getError(mat1))mat0,mat1=CART.splitDataSet(dataMat,0,0.5)print(mat0)print(mat1)print(mat0.shape)'''#bestIndex,bestValue=CART.chooseBestSplit(dataMat)#print(bestIndex,bestValue)#tree1tree1=CART.buildTree(dataMat1)print(tree1)#tree2tree2=CART.buildTree(dataMat2)print(tree2)x=[1.0,0.559009]print(CART.predict(tree2,x))

结果:
这里写图片描述

4 0
原创粉丝点击