《机器学习实战》笔记之九——树回归

来源:互联网 发布:mac保存到指定文件夹 编辑:程序博客网 时间:2024/05/17 04:58

第九章 树回归

  • CART算法
  • 回归与模型树
  • 树减枝算法
  • python中GUI的使用

线性回归需要拟合所有的样本点(局部加权线性回归除外),当数据拥有众多特征并且特征之间关系十分复杂时,就不可能使用全局线性模型来拟合任何数据。

将数据集切分成很多份易建模的数据,再用线性回归技术来建模可破。

本章介绍CART(Classification And Regression Trees, 分类回归树)的树构建算法,可用于分类还可用于回归。


9.1 复杂数据的局部性建

chap3的决策树主要是不断将数据切分成小数据集,直到所有目标变量完全相同,或者数据不能再切分为止。决策树是一种贪心算法,并不考虑能否达到全局最优。其构建算法ID3算法每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来划分,之后该特征不会再起作用。另外一种方法是二元切分法,每次把数据集切成两份,如果数据的某特征等于切分所要求的值,那么这些数据就进入树的左子树,反之右子树。二元切分法可处理连续型特征,节省树的构建时间。

CART使用二元切分来处理连续型变量,应用广泛。


9.2 连续型和离散型特征的树的构建

与chap3类似,用字典存储树的结构。包括4元素:

  • 待切分的特征
  • 待切分的特征值
  • 左子树。当不再需要切分的时候,可是个单个值。
  • 右子树。类似左子树。

CART算法可固定树的数据结构,树包含左键和右键,可以存储 另一颗子树或者单个值。

伪代码:

找到最佳的待切分特征:

如果该节点不能再分,将该节点存为叶节点

执行二元切分

在右子树调用createTree()方法

在左子树调用createTree()方法

coding:

#!/usr/bin/env python# coding=utf-8from numpy import *def loadDataSet(fileName):    dataMat = []    fr      = open(fileName)    for line in fr.readlines():        curLine = line.strip().split("\n")        fltLine = map(float, curLine)       #将每行的每个元素映射为浮点数        dataMat.append(fltLine)    return dataMatdef binSplitDataSet(dataSet, feature,value):#数据集合,待切分的特征,特征值,将数据集合切分得到两个子集    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0], :][0]    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0], :][0]    return mat0, mat1def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)        #将数据集进行切分    if feat == None:        return val    else:        retTree = {}        retTree["spInd"] = feat        retTree["spVal"] = val        lSet, rSet       = binSplitDataSet(dataSet, feat, val)        retTree["left"]  = createTree(lSet, leafType, errType, ops)     #递归切分        retTree["right"] = createTree(rSet, leafType, errType, ops)        return retTreetestMat    = mat(eye(4))mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)       #print testMatprint mat0print mat1

9.3 将CART算法用于回归

回归树假设叶节点是常数值。这种策略认为数据中的复杂关系可以用树结构来概括。

为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。首先计算所有数据的均值,然后计算每条 数据的值到均值的差值,一般使用绝对值或平方差 来代替差值 ,类似方差计算,方差为平方误差的均值(均方差),这里需要计算平方误差的总值(总方差),均方差(var函数)乘以数据集样本点的个数可破。

构建树

chooseBestSplit()函数目标是找到数据集切分的最佳位置,遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。

伪代码:

对每个特征:

对每个特征值:

将数据集切分成两份

计算切分的误差

如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差 返回最佳切分的特征和阈值

数据:


Figure 9-1: 实验数据部分样本数据

coding:

#==============回归树的切分函数=============================def regLeaf(dataSet):    return mean(dataSet[:,-1])                      #返回叶节点,回归树中的目标变量的均值def regErr(dataSet):    return var(dataSet[:,-1])*shape(dataSet)[0]     #误差估计,计算目标变量的平方误差,需要返回总误差def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):    tolS = ops[0]                                   #容许的误差下降值    tolN = ops[1]                                   #切分的最少样本数    if len(set(dataSet[:,-1].T.tolist()[0]))==1:    #如果目标值相等,退出        return None, leafType(dataSet)    else:        m,n       = shape(dataSet)        S         = errType(dataSet)        bestS     = inf        bestIndex = 0        bestValue = 0        for featIndex in range(n-1):                            #对所有特征进行遍历            for splitVal in set(dataSet[:,featIndex]):          #遍历某个特征的所有特征值                mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)   #按照某个特征的某个值将数据切分成两个数据子集                if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):           #如果某个子集行数不大于tolN,也不应该切分                    continue                newS = errType(mat0) + errType(mat1)            #新误差由切分后的两个数据子集组成的误差                if newS < bestS:                                #判断新切分能否降低误差,                    bestIndex = featIndex                    bestValue = splitVal                    bestS     = newS        if (S - bestS) <tolS:                                   #如果误差不大则退出            return None, leafType(dataSet)        mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)        if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):       #如果切分出的数据集很小则退出            return None, leafType(dataSet)        return bestIndex, bestValue        #=====================================================================myDat = loadDataSet("ex00.txt")myMat = mat(myDat)print createTree(myMat)myDat1 = loadDataSet("ex0.txt")myMat1 = mat(myDat1)print createTree(myMat1)

切分效果:

Figure 9-2: 切分效果

9.4 树剪枝

一棵树如果节点过多,表明该模型可能对数据进行了”过拟合“。可使用测试集上交叉验证来发现过拟合。

剪枝(pruning):降低决策树的复杂度来避免过拟合。分为预剪枝(prepruning)和后剪枝(postpruning),后者需要使用训练集和测试集。

预剪枝

树构建算法对输入参数tolS和tolN非常敏感,通过不断地修改停止条件来得到合理结果并不是很好的办法。

后剪枝

后剪枝需要使用测试集。首先指定参数,使得构建出的树足够大,足够复杂,便于剪枝。从上而下找到叶节点,用测试集来判断这些叶节点合并是否能降低测试误差。是则合并。

伪代码:

基于已有的树切分测试数据:

如果存在任一子集是一棵树,则在该子集递归剪枝过程

计算将当前两个叶节点合并后的误差

计算不合并的误差

如果合并会降低误差的化,就将叶节点合并

coding:

#=============回归树剪枝函数==========================================def isTree(obj):                            #测试输入变量是否是一棵树,用于判断当前处理的节点是否是叶节点    return (type(obj).__name__=="dict")def getMean(tree):          #递归从上到下,到叶节点,找到两个叶节点则计算它们的平均值    if isTree(tree["right"]):        tree["right"] = getMean(tree["right"])    if isTree(tree["left"]):        tree["left"]   = getMean(tree["left"])    return (tree["left"]+tree["right"])/2.0def prune(tree, testData):    if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree    if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)    #if they are now both leafs, see if we can merge them    if not isTree(tree['left']) and not isTree(tree['right']):        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\            sum(power(rSet[:,-1] - tree['right'],2))        treeMean = (tree['left']+tree['right'])/2.0        errorMerge = sum(power(testData[:,-1] - treeMean,2))        if errorMerge < errorNoMerge:             print "merging"            return treeMean        else: return tree    else: return treemyTree = createTree(myMat2, ops=(0,1))myDatTest = loadDataSet("ex2test.txt")myMat2Test = mat(myDatTest)print prune(myTree, myMat2Test)

9.5 模型树

用树来对数据建模,除了把叶节点设定为常数值外,还可以将其设定为分段线性函数,分段线性(piecewise linear)即模型由多个线性片段组成。

Figure 9-3: 用来测试模型树构建函数的分段线性数据

可以设计两条分别从0.0-0.3、从0.3~1.0的直线,得到两个线性模型,即分段线性模型。

两条直线比很多节点组成一颗大树更容易理解。模型树的可解释性是它优于回归树的特点之一。模型树也具有更高的预测准确度。利用树生成算法对数据进行切分,且每份切分数据都能很容易被线性模型所表示,关键在于找到最佳切分。

coding:


#==========模型树的叶节点生成函数=========def linearSolve(dataSet):           #执行简单的线性回归    m,n       = shape(dataSet)    X         = mat(ones((m,n)))    Y         = mat(ones((m,1)))    X[:, 1:n] = dataSet[:, 0:n-1]    Y         = dataSet[:, -1]      #将X,Y中的数据格式化    xTx       = X.T*X    if linalg.det(xTx) == 0.0:        raise NameError("This matrix is singular, cannot do inverse")        ws = linalg.pinv(xTx)*(X.T*Y)    ws = xTx.I*(X.T*Y)    return ws, X, Ydef modelLeaf(dataSet):             #当数据不再需要切分的时候它负责生成叶节点模型    ws, X, Y = linearSolve(dataSet)    return wsdef modelErr(dataSet):                  ws, X, Y = linearSolve(dataSet)    yHat     = X*ws    return sum(power(Y - yHat,2))       #计算平方误差myMat2 = mat(loadDataSet("exp2.txt"))#print myMat2,type(myMat2)print createTree(myMat2, modelLeaf, modelErr, (1,10))

效果:

Figure 9-4: 切分结果

两个线性模型分别为y=3.468+1.185x和y=0.00168+11.964x,实际数据是由模型y=3.5+1.0x和y=0+12x再加上高斯噪音生成,可以看出效果还是不错。

9.6 示例:树回归与标准回归的比较

计算模型树、回归树及其他模型效果,比较客观的方法是计算相关系数,R*R值,Numpy中corrcoef(yHat, y, rowvar = 0)也即皮尔逊相关系数。

coding:

#================用树回归进行预测的代码=============def regTreeEval(model, inDat):    return float(model)def modelTreeEval(model, inDat):    n = shape(inDat)[1]    X = mat(ones((1,n+1)))    X[:,1:n+1] = inDat    return float(X*model)def treeForeCast(tree, inData, modelEval=regTreeEval):    if not isTree(tree):        return modelEval(tree, inData)          #如果输入单个数据或行向量,返回一个浮点值    else:        if inData[tree["spInd"]] > tree["spVal"]:            if isTree(tree["left"]):                return treeForeCast(tree["left"], inData, modelEval)            else:                return modelEval(tree["left"], inData)        else:            if isTree(tree["right"]):                return treeForeCast(tree["right"], inData, modelEval)            else:                return modelEval(tree["right"], inData)def createForeCast(tree, testData, modelEval=regTreeEval):    m    = len(testData)    yHat = mat(zeros((m,1)))    for i in range(m):        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)#多次调用treeForeCast函数,将结果以列的形式放到yHat变量中    return yHattrainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt"))testMat  = mat(loadDataSet("bikeSpeedVsIq_test.txt"))myTree   = createTree(trainMat, ops=(1,20))yHat     = createForeCast(myTree, testMat[:,0])print "回归树的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]myTree   = createTree(trainMat, modelLeaf, modelErr,(1,20))yHat     = createForeCast(myTree, testMat[:,0], modelTreeEval)print "模型树的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]ws, X, Y = linearSolve(trainMat)print "线性回归系数:",wsfor i in range(shape(testMat)[0]):    yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0]print "线性回归模型的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

效果:

Figure 9-5: 回归树、模型树、简单线性回归的皮尔逊相关系数

可以看出模型树的结果比回归树效果好。


9.7 使用Python的Tkinter库创建GUI

本小结用到python的一个图形用户界面(GUI, Graphical User Interface)框架——Tkinter。

Tkinter的GUI由一些小部件组成(Widge)组成。小部件:文本框、按钮、标签和复选按钮等对象。

Figure 9-6: Tkinter的hello world

myLabel调用grid()方法时,把myLabel的位置告诉了布局管理器,grid()函数会把小部件安排在一个二维表格中。

coding:

#!/usr/bin/env python# coding=utf-8#用于构建树管理器界面的Tkinter小部件from numpy import *from Tkinter import *import regTreesdef reDraw(tolS, tolN):    passdef drawNewTree():    passroot = Tk()Label(root, text = "Plot Place Holder").grid(row=0, columnspan=3)    #设置文本,第0行,距0的行值为3,Label(root, text = "tolN").grid(row=1, column=0)tolNentry = Entry(root)                                              #Entry为允许单行文本输入的文本框,设置文本框,再定位置第1行第1列,再插入数值tolNentry.grid(row=1, column=1)tolNentry.insert(0,"10")Label(root, text="tolS").grid(row=2, column=0)tolSentry = Entry(root)tolSentry.grid(row=2, column=1)tolSentry.insert(0,"1.0")Button(root, text = "ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)#Botton按钮,设置第1行第2列,列值为3chkBtnVar = IntVar()                                                               #IntVar为按钮整数值小部件chkBtn    = Checkbutton(root, text = "Model Tree", variable = chkBtnVar)chkBtn.grid(row=3, column=0, columnspan = 2)reDraw.rawDat  = mat(regTrees.loadDataSet("sine.txt"))reDraw.testDat = arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]),0.01)reDraw(1.0,10)root.mainloop()

效果:

Figure 9-7: 使用多个Tkinter部件创建的树管理器

集成Matplotlib和Tkinter

matplotlib绘制的图像可以放到GUI上。matplotlib构建程序时包含一个前端,如plot、scatter函数,也同时创建了一个后端,用于实现绘图和不同应用之间的接口,改变后端可以将图像绘制在PNG,PDF,SVG等格式的文件上。matplotlib将后端设置为TkAgg,TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。

coding:

import matplotlibmatplotlib.use("TkAgg")                         #设定后端为TkAggfrom matplotlib.backends.backend_tkagg import FigureCanvasTkAggfrom matplotlib.figure import Figuredef reDraw(tolS,tolN):    reDraw.f.clf()                              #清空之前的图像    reDraw.a = reDraw.f.add_subplot(111)        #重新添加子图    if chkBtnVar.get():                         #检查复选框是否选中,确定是模型树还是回归树        if tolN<2:            tolN=2        myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr,(tolS,tolN))        yHat   = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval)    else:        myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS,tolN))        yHat   = regTrees.createForeCast(myTree, reDraw.testDat)    reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1],s=5)        #画真实值的散点图    reDraw.a.plot(reDraw.testDat,yHat,linewidth=2.0)                   #画预测值的直线图    reDraw.canvas.show()def getInputs():                                #获取用户输入的值,tolN期望得到整数值,tolS期望得到浮点数,    try:        tolN = int(tolNentry.get())             #在Entry部件调用get方法,    except:        tolN = 10        print "enter Integer for tolN"        tolNentry.delete(0,END)        tolNentry.insert(0,"10")    try:        tolS = float(tolSentry.get())    except:        tolS = 1.0        print "enter Integer for tolS"        tolSentry.delete(0,END)        tolSentry.insert(0,"1.0")    return tolN,tolSdef drawNewTree():                              #有人点击ReDraw按钮时就会调用该函数    tolN,tolS = getInputs()                     #得到输入框的值    reDraw(tolS,tolN)                           #调用reDraw函数root = Tk()reDraw.f = Figure(figsize=(5,4),dpi=100)reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)reDraw.canvas.show()reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

效果:

Figure 9-8: 用treeExplore的GUI构建的回归树


Figure 9-9: 模型树。参数为tolN=1,tolS=0

9.8小结

数据集中输入数据和目标变量呈非线性关系,可使用树结构来对预测值分段,包括分段常数或分段直线。叶节点使用分段常数则为回归树,若为线性回归方程则为模型树。


1 2
原创粉丝点击