第9章 机器学习实战之树回归

来源:互联网 发布:网页铃声制作软件 编辑:程序博客网 时间:2024/05/15 14:03

主要内容:
CART算法
回归与模型树
树剪枝算法

主要内容
● CART算法
● 回归与模型树
● 树剪枝算法
● python中GUI的使用

当数据有众多特征的时候且特征之间的关系十分复杂的时候,构建全局变量的想法就太难了。生活中许多实际问题都是非线性的,不可能全部使用全局线性模型来拟合数据。
我们可以利用树回归和回归法来切分数据,如果首次切分难以拟合模型就继续切分。

9.1 复杂数据的局部线性模型
决策数不断地将数据切分为小数据集,直到所有的目标变量完全相同,或者数据不在能切分为止。
决策树是一种贪心算法,要在给定的时间内做出最佳的选择,但不关心能否达到全局最优。

树回归
优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据

ID3的做法是每次选取最佳的特征来分割数据,并按照该特征的所有可能值来切分。ID3算法的缺点就是,其一切分过去迅速,其二,不能直接的处理连续型特征。
CART算法使用的是二元切分来处理连续型变量,稍作修改就可以处理回归问题。

树回归的一般方法:
(1)收集数据:采用任意方法收集数据
(2)准备数据:需要数值型的数据,标称型数据应该映射成二值型数据
(3)分析数据:绘出数据的二维可视化显示结果,以字典的方式生成树
(4)训练算法:大部分时间都花费在叶节点树模型的构建上
(5)测试算法:使用测试数据上的R^2值来分析模型的效果
(6)使用算法:使用训练出的树做模型,预测结果还可以用来做许多事情

9.2 连续喝离散型特征的树的构建
在树的构建过程中,需要解决多种类型数据的存储问题。这里使用一部字典来存储树的数据结构。
字典包含以下4个元素:
● 待切分的特征
● 待切分的特征值
● 右子树。当不在需要切分的时候,也可是单个值
● 左子树。与右子树类似。

CART算法做二元切分,所以这里可以固定树的数据结构。树的左键和右键可以存储在另一棵子树或者单个值。字典还包括特征和特征值这两个键,他们给出切分算啊所有的特征和特征值。

函数createTree()的伪代码如下:

#伪代码"""找到最佳的待切分特征:    如果该节点不能再分,将该节点存为叶节点    执行二元切分    在右子树调用createTree() 方法    在左子树调用createTree() 方法"""
# CART算法实现from numpy import *import numpy as npdef loadDataSet(fileName):    """    读取tab键分隔符的文件将每行的内容保存成一组浮点数    """    dataMat = []    fr = open(fileName)    for line in fr.readlines():        curLine = line.strip().split('\t')        fltLine = map(float,curLine)        dataMat.append(fltLine)    return dataMatdef binSplitDataSet(dataSet,feature,value):    """    参数:数据集合,待切分的特征和该特征的某个值    在给定特征和特征值,函数通过数组过滤方式将数据切分得到两个子集    """    mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:][0]    mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:][0]    return mat0,mat1def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    """    leafType=regLeaf 建立节点的函数    errType = regErr  代表计算误差的函数    ops=(1,4)  包含树所构建所需其他参数的元组    """    feat,val = chooseBestSplit(dataSet,leafType,errType,ops)    if feat == None:        return val    retTree = {}    retTree['spInd'] = feat    retTree['spVal'] = val    lSet, rSet = binSplitDataSet(dataSet,feat,val)    retTree['left'] = createTree(lSet,leafType,errType,ops)    retTree['right'] = createTree(lSet,leafType,errType,ops)    return retTree

1.loadDataSet()函数载入数据,读取一个tab键为分隔符的文件,将每行内容保存成一组浮点数。
2.binSplitDataSet()在给定特征和特征值得情况下,该函数通过数组过滤的方式将上述数据集合切分得到两个子集并返回。
3.createTree()是一个递归函数。调用函数binSplitDataSet()完成函数切分。

In [5]: testMat = np.mat(np.eye(4))In [7]: testMatOut[7]:matrix([[ 1., 0., 0., 0.],[ 0., 1., 0., 0.],[ 0., 0., 1., 0.],[ 0., 0., 0., 1.]])In [8]: mat0,mat1 = regTree.binSplitDataSet(testMat, 1, 0.5)In [9]: mat0Out[9]: matrix([[ 0., 1., 0., 0.]])In [10]: mat1Out[10]: matrix([[ 1., 0., 0., 0.]])

9.3 将CART算法用于回归
回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。

9.3.1 构建树
chooseBestSplit()函数:用最佳方式切分数据集和生成相应的叶节点。
函数chooseBestSplit()函数的伪代码如下:

##伪代码        """    对每个特征:        对每个特征            将数据集切分成两份            计算切分的误差            如果当前误差小于当前最小误差                那么将当前切分设定为最佳切分并更新最小误差    返回最佳的切分的特征和阈值    """
def regLeaf(dataSet):    return np.mean(dataSet[:,-1])  #负责生产叶节点def regErr(dataSet):    """    误差估计函数,计算目标变量的平均误差,调用均差函数var    """    return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
##回归树切分函数def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr, ops = (1,4)):    """    构建回归树的核心函数,找到最佳的二元切分方式    tolS 容许的误差下降值    tolN 切分的最少样本数    leafType 是对创建叶节点的函数的引用    errType 是对总方差的计算函数的引用    ops 是一个用户自定义的参数构成的元组,用以完成树的构建    """    tolS = ops[0]; tolN = ops[1]    if len(set(dataSet[:,-1].T.tolist()[0])) ==1:        return None,leafType(dataSet)   #如果所有值相等则退出    m, n = np.shape(dataSet)    S = errType(dataSet)    bestS = inf; bestIndex = 0; bestValue = 0    for featIndex in range(n-1):        for splitVal in set(dataSet[:,featIndex]):            mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):                continue            newS = errType(mat0) + errType(mat1)            if newS < bestS:                bestIndex = featIndex                bestValue = splitVal                bestS = newS    if (S - bestS) < tolS:        return None, leafType(dataSet)  #如果误差不大则退出    mat0,mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0]) < tolN:        return None, leafType(dataSet)  #如果切分的数据集很小则退出    return bestIndex, bestValue

最佳切分也就是使得切分贵阳能达到最低误差的切分。如果切分数据集后效果提升不大,那么就不应该进行切分操作,而应该直接创建叶节点。

9.3.2 运行代码

1.给的代码,运行会产生错误:

  File "regTree.py", line 109, in chooseBestSplit    for splitVal in set(dataSet[:,featIndex]):TypeError: unhashable type: 'matrix'

修改为:

for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0])

2.

  File "regTree.py", line 49, in binSplitDataSet    mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:][0]IndexError: index 0 is out of bounds for axis 0 with size 0

函数修改两行,正确结果为:

mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
In [39]: reload(regTree)Out[39]: <module 'regTree' from 'regTree.py'>In [40]: myDat = regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\ex00.txt')    ...:    ...: myMat = np.mat(myDat)    ...:In [41]: regTree.createTree(myMat)Out[41]:{'left': 1.0180967672413792,'right': 1.0180967672413792,'spInd': 0,'spVal': 0.48813}In [46]: myDat1 = regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\ex0.txt')    ...:In [48]: myMat1 = np.mat(myDat1)    ...:In [49]: regTree.createTree(myMat1)    ...:Out[49]:{'left': {'left': {'left': 3.9871631999999999,'right': 3.9871631999999999,'spInd': 1,'spVal': 0.797583},'right': {'left': 3.9871631999999999,'right': 3.9871631999999999,'spInd': 1,'spVal': 0.797583},'spInd': 1,'spVal': 0.582002},'right': {'left': {'left': 3.9871631999999999,'right': 3.9871631999999999,'spInd': 1,'spVal': 0.797583},'right': {'left': 3.9871631999999999,'right': 3.9871631999999999,'spInd': 1,'spVal': 0.797583},'spInd': 1,'spVal': 0.582002},'spInd': 1,'spVal': 0.39435}

画出ex00.txt 和 ex0.txt 多次切分之后的图:

这里写图片描述

这里写图片描述

9.4 树剪枝
一棵树如果节点过多,表明该模型可能对数据进行了“过拟合”,我们需要 使用测试机上某种交叉验证技术来发现是否过“拟合”。通过降低决策树的复杂度来避免过拟合来避免拟合的过程称之为 剪枝。函数chooseBestSplit()实际上就是一种预剪枝技术。
另一种形式的剪枝需要使用测试集和训练集,称之为后剪枝。

9.4.1 预剪枝
在9.3构建算法时候,算法对输入参数tols和参数tolN很敏感。

In [56]: myDat2 = regTrees.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\ex2.txt')    ...:In [57]: myMat2 = np.mat(myDat2)    ...:In [58]: regTree.createTree(myMat2)    ...:Out[58]:{'left': {'left': {'left': {'left': 105.24862350000001,'right': 105.24862350000001,'spInd': 0,'spVal': 0.958512},'right': {'left': 105.24862350000001,......'spVal': 0.952833},'spInd': 0,'spVal': 0.729397},'spInd': 0,'spVal': 0.499171}In [60]: plt.plot(myMat2[:,0],myMat2[:,1],'ro')    ...: plt.show()

这里写图片描述

上图构造出的树只有两个节点,这里构建的新树则有很多节点。这是因为tols对误差的数量级十分敏感。
若对上述误差的容忍度取平方值,或许可以得到两个叶节点组成的树。

In [61]: regTree.createTree(myMat2,ops=(10000, 4))Out[61]:{'left': 101.35815937735848,'right': 101.35815937735848,'spInd': 0,'spVal': 0.499171}

不断地修改停止条件来得到合理的结果不是很好的方法。

9.4.2 后剪枝
使用后剪枝的方法需要将数据集分成测试集和训练集。

"""基于已有的树切分测试数据:    如果存下任一子集是一棵树,则在该子集递归剪枝过程    计算将当前两个叶节点合并后的误差    计算不合并的误差    如果合并会降低误差的话,就将叶节点合并"""
# 回归树剪枝函数def isTree(obj):    return (type(obj).__name__ =='dict')def getMean(tree):    if isTree(tree['right']):tree['right'] = getMean(tree['right'])    if isTree(tree['left']):tree['left'] = getMean(tree['right'])    return (tree['left'] + tree['right'])/2.0def prune(tree, testData):    """    tree:待剪枝的树    testData:剪枝所需要的测试数据    """    if np.shape(testData)[0] == 0: return getMean(tree) #确认数据集是否为空    if (isTree(tree['right']) or isTree(tree['left'])):        lSet, rSet = binSplitDataSet(testData,tree['spInd'], tree['spVal'])    if isTree(tree['left']):        tree['left'] = prune(tree['left'], lSet)    if isTree(tree['right']):        tree['right'] = prune(tree['right'],rSet)    if not isTree(tree['left']) and not isTree(tree['right']):        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])        errorNoMerge = sum(np.power(testData[:, -1] - tree['left'], 2)) + sum(power(rSet[:,-1] - tree['right'],2))         treeMean = (tree['left'] + tree['right'])/2.0        errorMerge = sum(np.power(testData[:, -1] - treeMean, 2))         if errorMerge < errorNoMerge:            print "merging"            return treeMean        else:            return tree    else:        return tree

isTree() 用于测试输入变量是否是一棵树,返回布尔型结果
getMean() 是一个递归函数,他从上到下遍历树直到节点为止,若找到两个叶节点就计算他们的平均值。该函数对树进行塌陷处理(返回树的平均值)。
主函数prune(),先确认测试集是否为空。则反复调用递归函数prune()对测试数据集进行切分。

In [18]: myDataTest = regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\ex2test.txt')In [19]: myMat2Test = np.mat(myDataTest)In [21]: regTree.prune(myTree, myMat2Test)mergingmergingmerging......{'left': {'left': {'left': {'left': {'left': 86.399636999999998,'right': 86.399636999999998,'spInd': 0,'spVal': 0.968621},'right': 86.399636999999998,'spInd': 0, ......'spVal': 0.952833},'spInd': 0,'spVal': 0.729397},'spInd': 0,'spVal': 0.499171}

9.5 模型树
把叶节点设定为分段线性函数,分段线性是指模型由多个线段片组成。
这里写图片描述

用两条直线肯定比一组常数model效果更好,可以由0.0~0.3和0.3~1.0的两条直线组成。决策树相比其他机器学习算法易于理解,而模型树的可解释性是它优于回归树的特性之一。模型树同时具备更高的预测准确度。

回归树把前面两个参数略做修改就可以用于模型树。

##模型树的叶节点生成函数def lineearSolve(dataSet):    m, n = np.shape(dataSet)    X = np.mat(ones((m,n))); Y = np.mat(ones((m,1)))    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:, -1]    xTx = X.T * X    if np.linalg.det(xTx) == 0.0: #判断矩阵是否可逆        raise NameError('This matrix is singular, connot do inverse, \n\        try increasing the second value of ops')    ws = xTx.I * (X.T * Y)    return ws,X,Ydef moudleLeaf(dataSet):    ws, X, Y = lineearSolve(dataSet)    return wsdef moudleErr(dataSet):    ws, X, Y = lineearSolve(dataSet)    yHat = X * ws    return sum(np.power(Y - yHat, 2))

1.lineearSolve()函数将数据集格式化为目标变量X,和目标变量Y。X和Y用于执行简单的线性回归。
2.moudleErr()在给定的数据集上计算误差
3.moudleLeaf()当数据不需要切分的时候生成叶节点的模型。

In [31]: reload(regTree)Out[31]: <module 'regTree' from 'regTree.py'>In [32]: myMat2 = np.mat(regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\exp2.txt'))In [34]: regTree.createTree(myMat2, regTree.moudleLeaf, regTree.moudleErr, (1,10))Out[34]:{'left': matrix([[  1.69855694e-03],         [  1.19647739e+01]]), 'right': matrix([[ 3.46877936],         [ 1.18521743]]), 'spInd': 0, 'spVal': 0.285477}In [36]: plt.plot(myMat2[:,0],myMat2[:,1],'ro')    ...: plt.show()

这里写图片描述

从结果看,生成了 y = 3.468+1.852 和y=3.468+1.185x和y=0.00169+11.964x两个线性模型。 与用于生成的数据相比(y=3.5+1.0x和y=0+12x)比较贴近

9.6 树回归与标准回归的比较
之前介绍了 模型树、回归树、一般的回归方法,下面测试哪个模型好。给出的函数来计算三个模型的误差。

#回归树In [58]: reload(regTree)In [58]: <module 'regTree' from 'regTree.py'>In [59]:In [59]: trainMat= np.mat(regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\bikeSpeedVsIq_train.txt'))In [60]: testMat= np.mat(regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\bikeSpeedVsIq_test.txt'))In [61]: myTree = regTree.createTree(trainMat,ops = (1,20))In [62]: yHat = regTree.createForeCast(myTree,testMat[:,0])    ...:In [63]: corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]    ...:Out[63]: 0.96408523182221506
#模型树In [17]: myTree = regTree.createTree(trainMat,regTree.moudleLeaf, regTree.moudleErr, ops = (1,20))    ...:In [19]: yHat = regTree.createForeCast(myTree, testMat[:,0], regTree.modelTreeEval)    ...:In [20]: corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]    ...:Out[20]: 0.97604121913806097In [28]: plt.plot(trainMat[:,0],trainMat[:,1],'ro')    ...: plt.show()

这里写图片描述

#标准线性回归In [26]: ws, X, Y = regTree.linearSolve(trainMat)    ...:In [27]: wsOut[27]:matrix([[ 37.58916794],[ 6.18978355]])In [29]: for i in range(np.shape(testMat)[0]):    ...: yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0]    ...:In [30]: corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]    ...:Out[30]: 0.94346842356747662

R^2越接近1.0越好,由此看到在这里模型树的结果比回归树好。简单的线性模型的结果也没有上述的两种结果好。

9.7 使用python的Tkinter库创建GUI
机器学习总能给我们从未知的数据中抽取有用的信息。因此,能将这些信息可视化出来十分重要。
有一个同时支持数据呈现和用户交互的方式就是构建一个图形用户界面(GUI , grahical user interface)。

9.7.1 用Tkinter创建GUI
python的一个GUI框架是其中一个Tkinter,是随着python的标准编译版本发布的。
Tkinter的GUI由一些小部件组成,也就是文本框(Text Box),按钮(Button),标签(Label)和复选按钮(Check Button)等对象。

一个hello world的例子

In [10]: from Tkinter import *In [11]: root = Tk()In [13]: myLabel = Label(root, text = "Hello World")In [14]: myLabel.grid()In [15]: root.mainloop()

这里写图片描述

##9-6 用于构建树管理器界面的Tkinter小部件from numpy import *from Tkinter import *import regTreedef reDraw(tolS, tolN):    passdef drawNewTree():    passroot = Tk()Label(root,text = "Plot Place Holder").grid(row = 0, columnspan = 3)Label(root,text = "tolN").grid(row = 1, column = 0)tolNentry = Entry(root)tolNentry.grid(row = 1, column = 1)tolNentry.insert(0 ,'10')Label(root,text = "tolS").grid(row = 2, column = 0)tolSentry = Entry(root)tolSentry.grid(row = 2, column = 1)tolSentry.insert(0 ,'1.0')Button(root, text = "ReDraw", command = drawNewTree).grid(row=1,column=2,rowspan=3)chkBtnVar = IntVar()chkBtn = Checkbutton(root, text = "Model Tree", variable = chkBtnVar)chkBtn.grid(row = 3, column = 0, columnspan = 2)reDraw.rawDat = np.mat(regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\sine.txt'))reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:, 0]), 0.01)reDraw(1.0, 10)root.mainloop()

运行代码就会得到下图:
这里写图片描述

9.7.2 集成motplotlib 和Tkinter
通过修改motplotlib 的后端,达到在Tinker的GUI上绘图的目的。

matplotlib的构建程序包含一个前段,就是面向用户的一些代码,如plot()和scatter()等。它也同时也创建了一个后端。通过改变后端可以将图像绘制在PNG,PDF,SVG等格式的文件上。
下面将设置后端为TkAgg,TkAgg可以在所选的框架上调用Agg,把Agg呈现在画布上。并调用 .grid()来调整布局。

##Matplotlib和Tkinter的代码集成import matplotlibmatplotlib.use('TkAgg')                from matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figuredef reDraw(tolS, tolN):    reDraw.f.clf()    #清空之前的图像    reDraw.a = reDraw.f.add_subplot(111)  #重新添加新图    if chkBtnVar.get():        if tolN < 2:            tolN =2        myTree = regTree.createTree(reDraw.rawDat, regTree.moudleLeaf,\                                    regTree.moudleErr, (tolS, tolN))        yHat = regTree.createForeCast(myTree, reDraw.testDat,\                                      regTree.modelTreeEval)    else:        myTree = regTree.createTree(reDraw.rawDat, ops = (tolS, tolN))        yHat = regTree.createForeCast(myTree, reDraw.testDat)    reDraw.a.scatter(reDraw.rawDat[:, 0], reDraw.rawDat[:, 1], s = 5)    reDraw.a.plot(reDraw.testDat, yHat, linewidth = 2.0)    reDraw.canvas.show()def getInputs():    try:        tolN = int(tolNentry.get())    except:        tolN = 10        print "enter Integer for tolN"        tolNentry.delete(0, END)        tolNentry.insert(0, '10')    try:        tolS = float(tolSentry.get())    except:        tolS =1.0        print "enter Float for tolS"        tolSentry.delete(0, END)        tolSentry.insert(0 ,'1.0')    return tolN,tolSdef drawNewTree():    tolN, tolS = getInputs()    reDraw(tolS, tolN)root = Tk()reDraw.f = Figure(figsize = (5,4) ,dpi = 100)  #创建画布reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master = root)reDraw.canvas.show()reDraw.canvas.get_tk_widget().grid(row = 0, columnspan = 3)
##9-6 用于构建树管理器界面的Tkinter小部件#==============================================================================# from numpy import *# import numpy as np# from Tkinter import *# import regTree# # def reDraw(tolS, tolN):#     pass# # def drawNewTree():#     pass#==============================================================================#Label(root,text = "Plot Place Holder").grid(row = 0, columnspan = 3)Label(root,text = "tolN").grid(row = 1, column = 0)tolNentry = Entry(root)tolNentry.grid(row = 1, column = 1)tolNentry.insert(0 ,'10')Label(root,text = "tolS").grid(row = 2, column = 0)tolSentry = Entry(root)tolSentry.grid(row = 2, column = 1)tolSentry.insert(0 ,'1.0')Button(root, text = "ReDraw", command = drawNewTree).grid(row=1,column=2,rowspan=3)chkBtnVar = IntVar()chkBtn = Checkbutton(root, text = "Model Tree", variable = chkBtnVar)chkBtn.grid(row = 3, column = 0, columnspan = 2)reDraw.rawDat = np.mat(regTree.loadDataSet(r'E:\ML\ML_source_code\mlia\Ch09\sine.txt'))reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:, 0]), 0.01)reDraw(1.0, 10)root.mainloop()

分类回归树(离散型)
这里写图片描述

模型树(连续型)
这里写图片描述

9.8小结
数据集中有一些复杂的关系,使输入数据和目标变量之间呈现非线性相关的关系。一般采用树结构来进行建模。 若叶节点使用模型是分段常数则称之为回归树,若叶节点使用的模型是线性回归方程称之为模型树。

原创粉丝点击