机器学习实战-CART分类回归树
来源:互联网 发布:淘宝店铺心得 编辑:程序博客网 时间:2024/05/16 06:33
树回归
虽然线性回归有强大的功能,但是在遇到数据具有很多特征时且特征之间具有复杂的关系时,构建全局的模型就显得比较难,而且也比较笨重,而且实际中处理的数据一般都是非线性的,不可能用全局线性模型来拟合任何数据。
一种可行的方法就是将数据集切分成很多易建模的数据,首次切分后难以拟合就继续切分,在这种切分方式下,树结构和回归法就相当有用。
9.1 复杂数据的 局部建模
树回归
优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据
决策树ID3的回顾:
ID3的做法是每次选择最佳的特征来分割数据,并按特征的所有值进行切分。一旦按照某个特征切分之后,该特征在之后的算法就不起作用,所以有观点认为这种切分方法过于迅速。
另一种方法是二元切分法,即把每次数据切分成两份,如果数据的某特征值等于切分所要求的值,那么这些数据就进入树的左子树,反之进入右子树。
另外ID3的一个缺点是,在处理连续型数据时,需要事先将数据的连续型特征转换为离散型,这种转换会破坏连续型数据的内在特征。
而二元切分法则易于对树构建过程中进行调整以处理连续型特征。具体方法是特征值大于给定值就走左子树,否则走右子树。
树回归的一般方法
(1)收集数据:有数据可做就行
(2)准备数据:需要数值型的数据,标称型的数据应该映射成二值型数据
(3)分析数据:绘出数据的二维可视化显示结果,以字典的方式生成树
(4)训练算法:大部分时间花在叶节点树模型的构建上
(5)测试算法:使用测试数据,分析模型效果
(6)使用算法:使用该算法做其他事情
9.2 连续和离散型特征的树的构建
使用字典来存储树的数据结构,该字典包括以下4个元素。
待切分的特征
待切分的特征值
右子树,当不在需要切分的时候,也可以是一个单个值
左子树,与右子树类似
from numpy import *
#加载数据集函数
def loadfile(filename):
dataMat=[]
fr=open(filename)
for line in fr.readlines():
curline=line.strip().split('\t')
floatline=map(float,curline)#把curline列表中的元素都经过float操作
dataMat.append(floatline)
return dataMat
#二分切分数据集函数
def binSplitDataSet(dataset,feature,value):
matGreater=dataset[nonzero(dataset[:,feature]>value)[0],:]
matLess= dataset[nonzero(dataset[:, feature] <=value)[0], :]
return matGreater,matLess
#测试一下这两个函数
testmat=mat(eye(4))
a,b=binSplitDataSet(testmat,1,0.5)
print a
print b
构建树的函数为createTree()
其大致的伪代码如下:
找到最佳的切分特征:
如果该节点不能再分,该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
下面给出《机器学习实战》中的代码,注意不能照搬,书上有些代码有错,已改正。其中dataset必须是array数组形式,另外有些地方不需要加[0],自己敲代码时候注意一下就行了。
#求数据集的均值,即叶节点的值,作为回归值,详见《统计学习方法》
def regLeaf(dataSet):
return mean(dataSet[:,-1])
#求数据集的总方差,用来衡量数据集的混合度
def regErr(dataSet):
return var(dataSet[:,-1])*shape(dataSet)[0]
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4 )):
#errLimit是切分前后误差下降限制
#numLimit是切分后的子集的样本数目限制
errLimit=ops[0];numLimit=ops[1]
#第一个停止条件,所有的y值是相同的
dataSet=array(dataSet)
#print 'this',array(dataSet)[:,-1]
if len(set((dataSet[:,-1].T).tolist()))==1:
return None,leafType(dataSet)
m,n=shape(dataSet)
S=regErr(dataSet)
bestS=inf;bestIndex=0;bestValue=0
for featureIndex in range(n-1):
for splitValue in set(dataSet[:,featureIndex]):
Greater,Less=binSplitDataSet(dataSet,featureIndex,splitValue)
if shape(Greater)[0]<numLimit or shape(Less)[0]<numLimit:
continue
newS=errType(Greater)+errType(Less)
if newS<bestS:
bestIndex=featureIndex
bestValue=splitValue
bestS=newS
#第二个停止条件 划分前后的误差要大于误差的下降值
if S-bestS<errLimit:
return None,leafType(dataSet)
#第三个停止条件,切分后的两个子数据集的样本数要大于numLimit
Greater,Less=binSplitDataSet(dataSet,bestIndex,bestValue)
if shape(Greater)[0]<numLimit or shape(Less)[0]<numLimit:
return None,leafType(dataSet)
return bestIndex,bestValue
在构建树的代码中,有一个参数是ops,它是一个元素,包括tolS,指切分前后误差的差的允许值,即两次误差的差必须大于这个限制值;另外一个是tolN,表示切分之后的两个子集的样本数必须大于这个值。这个ops元组的设置,相当于,对树的构建过程中的限制,可以理解为对树的预剪枝,但是这个ops的设置(限制条件)对误差的数量级十分铭感,不同的ops值的设置,将会得到差别很大的树,并且回归的值的误差将会是数量级的变化。所以《机器学习实战》讲解了如何进行后剪枝。下面给出后剪枝的代码,关于后剪枝以及代码的解释都在代码中附上了。#判断是不是树def isTree(obj): return (type(obj).__name__=='dict')#def getMean(tree): if isTree(tree['right']): tree['right']=getMean(tree['right']) if isTree(tree['left']): tree['left']=getMean(tree['left']) return (tree['right']+tree['left'])/2.0#剪枝函数def prune(tree,testData): #参数是生成的树,和测试的数据集 #如果测试数据个数为0,则对每个树(子树)返回树的平均值 if shape(testData)[0]==0:return getMean(tree) if (isTree(tree['right'])) or isTree(tree['left']):#是一颗树 lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal']) #分别对左右子树进行后剪枝 if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) #如果两个分支不在是子树 if not isTree(tree['left']) and not isTree(tree['right']): lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal']) #没有合并的误差 errorNoMerge=sum(power(lSet[:,-1]-tree['left'],2))+\ sum(power(rSet[:,-1]-tree['right'],2)) #求合并后的误差 treeMean=(tree['left']+tree['right'])/2.0 errorMerge=sum(power(testData[:,-1]-treeMean,2)) #比较前后的误差,决定是否剪枝 if errorMerge<errorNoMerge: print "merging" return treeMean#合并后的叶节点为两者平均值 else:return tree return tree#此时已经从头剪枝到尾部了,返回treemyDataset2test=loadfile('ex2test.txt')myData2=loadfile('ex2.txt')myTree=createTree(myData2,ops=(0,1))#用测试数据集对ex2.txt的树进行剪枝print prune(myTree,myDataset2test)
模型树用树对数据进行建模,除了把叶节点简单地设定为常数值意外,还可以把叶节点设定为分段性函数,分段线性是指模型由多个线性片段组成。直接上代码吧!'''把节点设置为分段线性函数,分段线性是指模型由多个线性片段组成'''def linearSolve(dataSet):#构建模型 dataMat=mat(dataSet) m,n=shape(dataMat) X=mat(ones((m,n)));Y=mat(ones((m,1))) X[:,1:n]=dataMat[:,0:n-1]#之所以X[:,1:n]从1列开始是因为有个常数项b Y=dataMat[:,-1] XTX=X.T*X if linalg.det(XTX)==0.0: raise NameError('this matrix cannot do inverse!') ws=XTX.I*(X.T*Y) return ws,X,Ydef modelLeaf(dataSet):#返回线性模型的权重 ws,X,Y=linearSolve(dataSet) return wsdef modelErr(dataSet):#返回误差 ws,X,Y=linearSolve(dataSet) yHat=X*ws return sum(power(Y-yHat,2))'''modelLeaf与modelErr这两个函数是用来生成叶节点的,不过生成的不是一个值而是一个线性模型;同理modelErr是用来计算误差的,这两个函数调用时,都会在里面调用linearModel函数,因为需要用划分的子数据集生成线性模型''''''下面只需要在createTree()函数中替换leafType,errType这两个函数型参数'''myMat2=loadfile('exp2.txt')print createTree(myMat2,modelLeaf,modelErr,(1,10))总结:数据集经常包括一些复杂的关系,使得输入数据与目标变量之间呈现非线性的关系。一般可以采用树结构来进行数据的预测,若叶节点是分段常数,则称为回归树;若叶节点设定为线性回归方程,则称为模型树。对于树的剪枝操作一般有两种: 预剪枝(在树的构建过程中进行剪枝) 后剪枝(在树构建完成后进行剪枝)一般可以两者配合使用。
1 0
- 机器学习实战-CART分类回归树
- 机器学习算法-分类回归树CART
- 机器学习之分类回归树CART
- 机器学习算法之CART(分类回归树)概要
- 机器学习算法之CART(分类和回归树)
- 机器学习笔记十二:分类与回归树CART
- 分类与回归树(CART)- 机器学习ML
- 机器学习实战 -ch09.树回归(CART算法)
- py2.7 : 《机器学习实战》树回归 3.8号 CART算法用于回归
- 分类回归树CART
- CART分类回归树
- CART-分类回归树
- CART分类回归树
- 机器学习实战之数回归,CART算法
- 机器学习实战--CART
- 机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树
- 机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树
- 机器学习十大算法之-CART分类决策树、回归树和模型树
- 解决 pyspider的 css_selector_helper 无法使用
- Debian 8安装、配置
- Qt相关概念
- 5. Longest Palindromic Substring
- 移位溢注:告别依靠人品的偏移注入
- 机器学习实战-CART分类回归树
- java编程思想 -- 流程控制语句
- A build only device cannot be used to run this target.----献给新手
- Java基础知识(二)
- 1759:最长上升子序列
- dbcp连接池基本参数解释
- spring框架搭建与入门案例
- Centos6.5安装与配置JDK-8
- 更新Android studio gradle需要注意的地方