【机器学习实战 第九章】树回归 CART算法的原理与实现

来源:互联网 发布:粒子群算法电子书 编辑:程序博客网 时间:2024/05/19 16:37

本文来自《机器学习实战》(Peter Harrington)第九章“树回归”部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐clone仓库后run cell all就可以了。

github地址:https://github.com/gshtime/machinelearning-in-action-python3

转载请标明原文链接

1 原理

CART(Classification and Regression Trees,分类回归树)是决策树算法的一种,这种树构建算法既可以用于分类也可以用于回归。

它采用一种递归二元分割(recursive binary splitting)的技术,分割方法采用基于最小距离的基尼指数(分类树中)或最小平方残差(回归树中)等方法来估计函数的不纯度,从而将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

因此,CART的目标是:选择输入变量和那些变量上的分割点,直到创建出适当的树。在这个过程中,使用贪婪算法(greedy algorithm)选择使用哪个输入变量和分割点,以使成本函数(cost function)最小化。

1.1 CART回归树的原理

本文主要讲解CART回归树的原理及实现

现在关注一下回归树的 CART 算法的细节。简要来说,创建一个决策树包含两步:

  1. 把预测器空间,即一系列可能值 \(X_1,X_2,...,X_p\) 分成 \(J\) 个不同的且非重叠的区域 \(R_1,R_2,...,R_J\)

  2. 对进入区域 \(R_J\) 的每一个样本观测值都进行相同的预测,该预测就是 \(R_J\) 中训练样本预测值的均值。

为了创建 \(J\) 个区域 \(R_1,R_2,...,R_J\),预测器区域被分为高维度的矩形或盒形。其目的在于通过下列式子找到能够使\(RSS\) 最小化的盒形区域 \(R_1,R_2,...,R_J\)
\[\sum_{j=1}^{J} \sum_{i \in R_j} \big(y_i - \hat{y}_{R_j}\big)^2\]

其中,\(\hat{y}_{R_j}\) 即是第 \(j\) 个盒形中训练观测的平均预测值。

鉴于这种空间分割在计算上是不可行的,因此我们常使用贪婪方法(greedy approach)来划分区域,叫做递归二元分割(recursive binary splitting)。

它是贪婪的(greedy),这是因为在创建树过程中的每一步骤,最佳分割都会在每个特定步骤选定,而不是对未来进行预测,并选取一个将会在未来步骤中出现且有助于创建更好的树的分割。注意所有的划分区域\(R_j,∀j∈[1,J]\) 都是矩形。为了进行递归二元分割,首先选取预测器 \(X_j\) (即数据集中的一个特征)和切割点 \(s\)(即该特征下某一个数据的值),递归遍历该特征下面所有的值作为二元分割的切割点,对预测器(特征)下的数据分割到不同的区域,即:\(R_1(j,s)=\big\{ X|Xj < s \big\} 和 R_2(j,s)=\big\{ X|Xj \ge s \big\}\),使得代价函数RSS得到最大程度的下降。从数学上讲,就是要寻找区域数J(我理解为叶节点数量)和分割点s,使分割后的代价函数最小化:
​​
\[ \sum_{i: x_i \in R_1(j,s)} \big(y_i-\hat{y}_{R_1}\big)^2 + \sum_{i: x_i \in R_2(j,s)} \big(y_i-\hat{y}_{R_2}\big)^2 \]

其中 \(\hat{y}_{R_1}\) 为区域 \(R_1(j,s)\) 中观察样本的平均预测值,\(\hat{y}_{R_2}\) 为区域\(R_2(j,s)\) 的观察样本预测均值。这一过程不断重复以搜寻最好的预测器和切分点,并进一步分隔数据以使每一个子区域内的 RSS 最小化。然而,我们不会分割整个预测器空间,我们只会分割一个或两个前面已经认定的区域。这一过程会一直持续,直到达到停止准则,例如我们可以设定停止准则为每一个区域最多包含 m 个观察样本。一旦我们创建了区域\(R_1、R_2、...、R_J\),给定一个测试样本,我们就可以用该区域所有训练样本的平均预测值来预测该测试样本的值。

2 代码

2.1 CART回归树实现

代码比较长,不知道cnblogs中是否能折叠,为了方便复制,还是都放在一块吧,github中的代码是分开的,有需要可以去看。

原书regTrees.py部分的代码如下

# -*- coding: utf-8 -*-import numpy as npdef loadDataSet(fileName):    '''    read the data file using TAB as separator,and store the data in float list    '''    dataMat = []    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')        fltLine = list(map(float, curLine))        dataMat.append(fltLine)    return dataMatdef binSplitDataSet(dataSet, feature, value):    mat0 = dataSet[np.nonzero(dataSet[:,feature]  > value)[0],:]    mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]    return mat0, mat1def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    feat, val =chooseBestSplit(dataSet, leafType, errType, ops)    if feat == None: return val    retTree = {}    retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataSet(dataSet, feat, val)    retTree['left'] = createTree(lSet, leafType, errType, ops)    retTree['right'] = createTree(rSet, leafType, errType, ops)    return retTreedef regLeaf(dataSet):    return np.mean(dataSet[:, -1])def regErr(dataSet):    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]#choose the best feature and splitting valuedef chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    tolS = ops[0] #tolerant value of S decilne    tolN = ops[1] #min number of samples to be splitted    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:        return None, leafType(dataSet)    m,n = np.shape(dataSet)    S = errType(dataSet)    bestS = np.inf;    bestIndex= 0;    bestValue = 0    for featIndex in range(n-1):        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):            mat0, mat1 = binSplitDataSet(dataSet,featIndex, splitVal)            if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue            newS = errType(mat0) + errType(mat1)            if newS < bestS:                bestIndex = featIndex                bestValue = splitVal                bestS = newS    #verdict whether the deciline of S reach the tolS or not    if (S - bestS) < tolS:        return None, leafType(dataSet)    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)    if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):        return None, leafType(dataSet)    return bestIndex, bestValuedef isTree(obj):    return (type(obj).__name__=='dict')def getMean(tree):    if isTree(tree['right']): tree['right'] = getMean(tree['right'])    if isTree(tree['left']) : tree['left']  = getMean(tree['left'])    return (tree['left'] + tree['right'])/2.0def prune(tree, testData):    if np.shape(testData)[0] == 0: return getMean(tree)    if(isTree(tree['right']) or isTree(tree['left'])):        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)    if isTree(tree['right']):tree['right']= prune(tree['right'],rSet)    if not isTree(tree['left']) and not isTree(tree['right']):        lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])        errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2))        treeMean = (tree['left']+tree['right'])/2.0        errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))        if errorMerge < errorNoMerge:            print("merging")            return treeMean        else: return tree    else: return treedef linearSolve(dataSet):    m,n = np.shape(dataSet)    X = np.mat(np.ones((m,n)))    Y = np.mat(np.ones((m,1)))    X[:,1:n] = dataSet[:,0:n-1]    Y = dataSet[:,-1]    xTx = X.T*X    if np.linalg.det(xTx) == 0.0:        raise NameError("This matrix is singular, cannot do inverse,\ntry increasing the second value of ops")    ws = xTx.I * (X.T*Y)    return ws, X , Ydef modelLeaf(dataSet):    ws, X, Y = linearSolve(dataSet)    return wsdef modelErr(dataSet):    ws, X, Y = linearSolve(dataSet)    yHat = X * ws    return np.sum(np.power(Y - yHat, 2))def regTreeEval(model, inDat):    return float(model)def modelTreeEval(model, inDat):    n = np.shape(inDat)[1]    X = np.mat(np.ones((1,n+1)))    X[:,1:n+1] = inDat    return float(X*model)def treeForecast(tree, inData, modelEval=regTreeEval):    if not isTree(tree): return modelEval(tree, inData)    if inData[tree['spInd']] > tree['spVal']:        if isTree(tree['left']):            return treeForecast(tree['left'], inData, modelEval)        else:            return modelEval(tree['left'], inData)    else:        if isTree(tree['right']):            return treeForecast(tree['right'], inData, modelEval)        else:            return modelEval(tree['right'], inData)def createForecast(tree, testData, modelEval=regTreeEval):    m = len(testData)    yHat = np.mat(np.zeros((m,1)))    for i in range(m):        yHat[i,0] = treeForecast(tree, np.mat(testData[i]), modelEval)    return yHat

2.2 使用python3的tkinter库创建GUI

python 2to3

原书的代码使针对python2.x环境构建的,在python2.x中应该import Tkinter,而在python3.x中,应该import tkinter才能正常导入Tkinter库

代码

# -*- coding:utf-8 -*-import tkinter as tkimport matplotlibmatplotlib.use('TkAgg')from matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figuredef reDraw(tolS, tolN):    reDraw.f.clf()    reDraw.a = reDraw.f.add_subplot(111)        if chkBtnVar.get():        if tolN < 2: tolN = 2        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN))        yHat = createForecast(myTree, reDraw.testDat, modelTreeEval)    else:        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))        yHat = createForecast(myTree, reDraw.testDat)            reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)        reDraw.canvas.show()      def getInput():    try:        tolN = int(tolNentry.get())    except:        tolN = 10        print("enter Integet for tolN")        tolNentry.delete(0, tk.END)        tolNentry.insert(0, "10")    try:        tolS = float(tolSentry.get())    except:        tolS = 1.0        print("enter Integet for tolS")        tolNentry.delete(0, tk.END)        tolNentry.insert(0, "1.0")    return tolN, tolSdef drawNewTree():    tolN, tolS = getInput()    reDraw(tolS, tolN)root = tk.Tk()#tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)reDraw.f = Figure(figsize=(5,4), dpi=100)reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)reDraw.canvas.show()reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)tk.Label(root, text="tolN").grid(row=1, column=0)tolNentry = tk.Entry(root)tolNentry.grid(row=1, column=1)tolNentry.insert(0, '10')tk.Label(root, text="tolS").grid(row=2, column=0)tolSentry = tk.Entry(root)tolSentry.grid(row=2, column=1)tolSentry.insert(0, '1.0')tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3)chkBtnVar = tk.IntVar()chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)chkBtn.grid(row=3, column=0, columnspan=2)reDraw.rawDat = np.mat(loadDataSet('./data/sine.txt'))reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01)reDraw(1.0, 10)root.mainloop()

测试代码

测试的代码都在书里,我的github仓库里也有,有空我再放这儿吧

注意

有时候运行tkinter的时候,可能python会无限地崩溃,可以试一下重装matplotlib库来解决

参考资料

  1. https://zhuanlan.zhihu.com/p/28217071
    这是一篇文章的中文翻译,推荐大家看看该文章的英文原文,这篇文章我觉得写得很棒,对了解CART有很大帮助,文中给出了借助sklearn库的CART实现方法,比较简单,另外作者给了其他决策树算法的文章链接。总之很推荐。
  2. http://blog.csdn.net/u014568921/article/details/45082197

写得比较仓促,自己也在理解和学习中,如果有不对的地方,还请多多指正。现在时间晚了,回头有空把这篇文章写得更全一点

阅读全文
0 0
原创粉丝点击
热门问题 老师的惩罚 人脸识别 我在镇武司摸鱼那些年 重生之率土为王 我在大康的咸鱼生活 盘龙之生命进化 天生仙种 凡人之先天五行 春回大明朝 姑娘不必设防,我是瞎子 淘宝店开了没做怎么办 微店店铺严重违规怎么办 淘宝违规扣2分怎么办 淘宝被扣6分怎么办 淘宝被扣2分怎么办 淘宝被海关扣了怎么办 淘宝被扣36分后怎么办 淘宝售假查封店铺资金怎么办 淘宝店扣48分怎么办 淘宝a内被扣48分怎么办 饿了么店铺满减怎么办 淘宝店扣a48分怎么办 淘宝短信营销无法获取人群怎么办 淘宝货发了退款怎么办 极速退款后卖家不确认收货怎么办 把货退了卖家不退款怎么办? 退款了又收到货怎么办 退货忘了填单号怎么办 手机换号了淘宝怎么办 换了手机支付宝怎么办 手机丢了微信登不上去了怎么办 前面手机丢了微信登不上去怎么办 淘宝密码忘了怎么办呢 融e借逾期一天怎么办 拼多多处罚下架怎么办 永久无法解绑支付宝怎么办 淘宝下单购买人数太多怎么办 新浪微博被拉黑暂时无法评论怎么办 闲鱼交易成功后卖家反悔怎么办 闲鱼买家不申请介入怎么办 支付宝安装不上怎么办 无线摄像机离wifi太远怎么办 安卓系统死机了怎么办 安卓手机开不了机怎么办 手机关机键坏了怎么办 华为手机接听电话声音小怎么办 小米6x游戏闪退怎么办 安卓8.0不兼容怎么办 安卓8.0应用闪退怎么办 安卓8.0不兼容的怎么办 游戏全屏只有一个分辨率选项怎么办