【机器学习实战 第九章】树回归 CART算法的原理与实现
来源:互联网 发布:酒店水牌通过网络发送 编辑:程序博客网 时间:2024/05/28 23:12
本文来自《机器学习实战》(Peter Harrington)第九章“树回归”部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐clone仓库后run cell all就可以了。
github地址:https://github.com/gshtime/machinelearning-in-action-python3
转载请标明原文链接
1 原理
CART(Classification and Regression Trees,分类回归树)是决策树算法的一种,这种树构建算法既可以用于分类也可以用于回归。
它采用一种递归二元分割(recursive binary splitting)的技术,分割方法采用基于最小距离的基尼指数(分类树中)或最小平方残差(回归树中)等方法来估计函数的不纯度,从而将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。
因此,CART的目标是:选择输入变量和那些变量上的分割点,直到创建出适当的树。在这个过程中,使用贪婪算法(greedy algorithm)选择使用哪个输入变量和分割点,以使成本函数(cost function)最小化。
1.1 CART回归树的原理
本文主要讲解CART回归树的原理及实现
现在关注一下回归树的 CART 算法的细节。简要来说,创建一个决策树包含两步:
把预测器空间,即一系列可能值 \(X_1,X_2,...,X_p\) 分成 \(J\) 个不同的且非重叠的区域 \(R_1,R_2,...,R_J\)。
对进入区域 \(R_J\) 的每一个样本观测值都进行相同的预测,该预测就是 \(R_J\) 中训练样本预测值的均值。
为了创建 \(J\) 个区域 \(R_1,R_2,...,R_J\),预测器区域被分为高维度的矩形或盒形。其目的在于通过下列式子找到能够使\(RSS\) 最小化的盒形区域 \(R_1,R_2,...,R_J\),
\[\sum_{j=1}^{J} \sum_{i \in R_j} \big(y_i - \hat{y}_{R_j}\big)^2\]
其中,\(\hat{y}_{R_j}\) 即是第 \(j\) 个盒形中训练观测的平均预测值。
鉴于这种空间分割在计算上是不可行的,因此我们常使用贪婪方法(greedy approach)来划分区域,叫做递归二元分割(recursive binary splitting)。
它是贪婪的(greedy),这是因为在创建树过程中的每一步骤,最佳分割都会在每个特定步骤选定,而不是对未来进行预测,并选取一个将会在未来步骤中出现且有助于创建更好的树的分割。注意所有的划分区域\(R_j,∀j∈[1,J]\) 都是矩形。为了进行递归二元分割,首先选取预测器 \(X_j\) (即数据集中的一个特征)和切割点 \(s\)(即该特征下某一个数据的值),递归遍历该特征下面所有的值作为二元分割的切割点,对预测器(特征)下的数据分割到不同的区域,即:\(R_1(j,s)=\big\{ X|Xj < s \big\} 和 R_2(j,s)=\big\{ X|Xj \ge s \big\}\),使得代价函数RSS得到最大程度的下降。从数学上讲,就是要寻找区域数J(我理解为叶节点数量)和分割点s,使分割后的代价函数最小化:
\[ \sum_{i: x_i \in R_1(j,s)} \big(y_i-\hat{y}_{R_1}\big)^2 + \sum_{i: x_i \in R_2(j,s)} \big(y_i-\hat{y}_{R_2}\big)^2 \]
其中 \(\hat{y}_{R_1}\) 为区域 \(R_1(j,s)\) 中观察样本的平均预测值,\(\hat{y}_{R_2}\) 为区域\(R_2(j,s)\) 的观察样本预测均值。这一过程不断重复以搜寻最好的预测器和切分点,并进一步分隔数据以使每一个子区域内的 RSS 最小化。然而,我们不会分割整个预测器空间,我们只会分割一个或两个前面已经认定的区域。这一过程会一直持续,直到达到停止准则,例如我们可以设定停止准则为每一个区域最多包含 m 个观察样本。一旦我们创建了区域\(R_1、R_2、...、R_J\),给定一个测试样本,我们就可以用该区域所有训练样本的平均预测值来预测该测试样本的值。
2 代码
2.1 CART回归树实现
代码比较长,不知道cnblogs中是否能折叠,为了方便复制,还是都放在一块吧,github中的代码是分开的,有需要可以去看。
原书regTrees.py部分的代码如下
# -*- coding: utf-8 -*-import numpy as npdef loadDataSet(fileName): ''' read the data file using TAB as separator,and store the data in float list ''' dataMat = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float, curLine)) dataMat.append(fltLine) return dataMatdef binSplitDataSet(dataSet, feature, value): mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:] mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:] return mat0, mat1def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): feat, val =chooseBestSplit(dataSet, leafType, errType, ops) if feat == None: return val retTree = {} retTree['spInd'] = feat retTree['spVal'] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTreedef regLeaf(dataSet): return np.mean(dataSet[:, -1])def regErr(dataSet): return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]#choose the best feature and splitting valuedef chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): tolS = ops[0] #tolerant value of S decilne tolN = ops[1] #min number of samples to be splitted if len(set(dataSet[:,-1].T.tolist()[0])) == 1: return None, leafType(dataSet) m,n = np.shape(dataSet) S = errType(dataSet) bestS = np.inf; bestIndex= 0; bestValue = 0 for featIndex in range(n-1): for splitVal in set(dataSet[:,featIndex].T.tolist()[0]): mat0, mat1 = binSplitDataSet(dataSet,featIndex, splitVal) if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue newS = errType(mat0) + errType(mat1) if newS < bestS: bestIndex = featIndex bestValue = splitVal bestS = newS #verdict whether the deciline of S reach the tolS or not if (S - bestS) < tolS: return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): return None, leafType(dataSet) return bestIndex, bestValuedef isTree(obj): return (type(obj).__name__=='dict')def getMean(tree): if isTree(tree['right']): tree['right'] = getMean(tree['right']) if isTree(tree['left']) : tree['left'] = getMean(tree['left']) return (tree['left'] + tree['right'])/2.0def prune(tree, testData): if np.shape(testData)[0] == 0: return getMean(tree) if(isTree(tree['right']) or isTree(tree['left'])): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) if isTree(tree['right']):tree['right']= prune(tree['right'],rSet) if not isTree(tree['left']) and not isTree(tree['right']): lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal']) errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2)) treeMean = (tree['left']+tree['right'])/2.0 errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2)) if errorMerge < errorNoMerge: print("merging") return treeMean else: return tree else: return treedef linearSolve(dataSet): m,n = np.shape(dataSet) X = np.mat(np.ones((m,n))) Y = np.mat(np.ones((m,1))) X[:,1:n] = dataSet[:,0:n-1] Y = dataSet[:,-1] xTx = X.T*X if np.linalg.det(xTx) == 0.0: raise NameError("This matrix is singular, cannot do inverse,\ntry increasing the second value of ops") ws = xTx.I * (X.T*Y) return ws, X , Ydef modelLeaf(dataSet): ws, X, Y = linearSolve(dataSet) return wsdef modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X * ws return np.sum(np.power(Y - yHat, 2))def regTreeEval(model, inDat): return float(model)def modelTreeEval(model, inDat): n = np.shape(inDat)[1] X = np.mat(np.ones((1,n+1))) X[:,1:n+1] = inDat return float(X*model)def treeForecast(tree, inData, modelEval=regTreeEval): if not isTree(tree): return modelEval(tree, inData) if inData[tree['spInd']] > tree['spVal']: if isTree(tree['left']): return treeForecast(tree['left'], inData, modelEval) else: return modelEval(tree['left'], inData) else: if isTree(tree['right']): return treeForecast(tree['right'], inData, modelEval) else: return modelEval(tree['right'], inData)def createForecast(tree, testData, modelEval=regTreeEval): m = len(testData) yHat = np.mat(np.zeros((m,1))) for i in range(m): yHat[i,0] = treeForecast(tree, np.mat(testData[i]), modelEval) return yHat
2.2 使用python3的tkinter库创建GUI
python 2to3
原书的代码使针对python2.x环境构建的,在python2.x中应该import Tkinter
,而在python3.x中,应该import tkinter
才能正常导入Tkinter库
代码
# -*- coding:utf-8 -*-import tkinter as tkimport matplotlibmatplotlib.use('TkAgg')from matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figuredef reDraw(tolS, tolN): reDraw.f.clf() reDraw.a = reDraw.f.add_subplot(111) if chkBtnVar.get(): if tolN < 2: tolN = 2 myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN)) yHat = createForecast(myTree, reDraw.testDat, modelTreeEval) else: myTree = createTree(reDraw.rawDat, ops=(tolS, tolN)) yHat = createForecast(myTree, reDraw.testDat) reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5) reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) reDraw.canvas.show() def getInput(): try: tolN = int(tolNentry.get()) except: tolN = 10 print("enter Integet for tolN") tolNentry.delete(0, tk.END) tolNentry.insert(0, "10") try: tolS = float(tolSentry.get()) except: tolS = 1.0 print("enter Integet for tolS") tolNentry.delete(0, tk.END) tolNentry.insert(0, "1.0") return tolN, tolSdef drawNewTree(): tolN, tolS = getInput() reDraw(tolS, tolN)root = tk.Tk()#tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)reDraw.f = Figure(figsize=(5,4), dpi=100)reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)reDraw.canvas.show()reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)tk.Label(root, text="tolN").grid(row=1, column=0)tolNentry = tk.Entry(root)tolNentry.grid(row=1, column=1)tolNentry.insert(0, '10')tk.Label(root, text="tolS").grid(row=2, column=0)tolSentry = tk.Entry(root)tolSentry.grid(row=2, column=1)tolSentry.insert(0, '1.0')tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3)chkBtnVar = tk.IntVar()chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)chkBtn.grid(row=3, column=0, columnspan=2)reDraw.rawDat = np.mat(loadDataSet('./data/sine.txt'))reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01)reDraw(1.0, 10)root.mainloop()
测试代码
测试的代码都在书里,我的github仓库里也有,有空我再放这儿吧
注意
有时候运行tkinter的时候,可能python会无限地崩溃,可以试一下重装matplotlib库来解决
参考资料
- https://zhuanlan.zhihu.com/p/28217071
这是一篇文章的中文翻译,推荐大家看看该文章的英文原文,这篇文章我觉得写得很棒,对了解CART有很大帮助,文中给出了借助sklearn库的CART实现方法,比较简单,另外作者给了其他决策树算法的文章链接。总之很推荐。 - http://blog.csdn.net/u014568921/article/details/45082197
写得比较仓促,自己也在理解和学习中,如果有不对的地方,还请多多指正。现在时间晚了,回头有空把这篇文章写得更全一点
- 【机器学习实战 第九章】树回归 CART算法的原理与实现
- 【机器学习实战 第九章】树回归 CART算法的原理与实现
- 机器学习实战 -ch09.树回归(CART算法)
- 机器学习实战-CART分类回归树
- 机器学习算法-分类回归树CART
- py2.7 : 《机器学习实战》树回归 3.8号 CART算法用于回归
- 机器学习实战之数回归,CART算法
- CART分类与回归树的原理与实现
- CART分类与回归树的原理与实现
- 机器学习实战 第九章 树回归 学习笔记
- 机器学习实战——第九章:树回归
- [完]机器学习实战 第九章 树回归
- 机器学习:决策树cart算法在分类与回归的应用(上)
- 机器学习:决策树cart算法在分类与回归的应用(下)
- 机器学习算法之CART(分类回归树)概要
- 机器学习算法之CART(分类和回归树)
- 机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树
- 机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树
- JavaScript中的==和===
- Leetcode 21. Merge Two Sorted Lists
- 方格分割
- 虚拟币开发专题(山寨币怎样通过挖矿最后把储存的币出完)
- python入门(三十):类的成员
- 【机器学习实战 第九章】树回归 CART算法的原理与实现
- JS实现复制功能,兼容各大主流浏览器复制神器 ZeroClipboard
- 使用序列化和反序列化实现深拷贝
- 关于《 MATLAB神经网络30个案例分析》坑的控诉
- js实现内容模块展开和收缩
- C++的引用与重载函数
- Github本地仓库与远程仓库使用心得
- Docker 使用总结
- xv6上下文切换代码