# CART 回归树代码实现 (CART regression tree implementation)

# 来源:互联网 发布:云教学平台为您优化 编辑:程序博客网 时间:2024/05/18 17:04
from numpy import *


def loadData(fileName):
    """Load a tab-separated file of numbers into a NumPy matrix.

    Each non-blank line becomes one row; its tab-separated fields are
    converted with ``float``. The tree code in this file treats the last
    column as the regression target.

    Args:
        fileName: path of the text file to read.

    Returns:
        A ``numpy.matrix`` with one row per non-blank input line.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field is not a valid float.
    """
    rows = []
    # `with` closes the file even if parsing raises part-way through.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) would make float('') fail.
                continue
            rows.append([float(field) for field in line.split('\t')])
    # asmatrix instead of mat: the numpy.mat alias was removed in NumPy 2.0.
    return asmatrix(rows)


class treeNode():
    """One node of the regression tree (printed in pre-order by tree()).

    For internal nodes ``feat`` is the split column index and ``value`` the
    split threshold; leaf nodes have ``feat`` set to None and keep their
    prediction in ``value``.
    """

    def __init__(self, feat, value, leftChild=None, rightChild=None):
        self.feat = feat
        self.value = value
        self.lc = leftChild    # subtree for rows whose feature value is > value
        self.rc = rightChild   # subtree for the remaining rows (<= value)
# ---- Building the regression tree ----


def splitData(dataSet, feat, val):
    """Binary-split the rows of ``dataSet`` on column ``feat`` at ``val``.

    Args:
        dataSet: 2-D ``numpy.matrix`` of samples (last column = target).
        feat: index of the feature column to split on.
        val: split threshold; a scalar (a 1x1 matrix also works).

    Returns:
        ``(mat0, mat1)``: ``mat0`` holds the full rows with
        ``dataSet[:, feat] > val``, ``mat1`` all the other rows. An empty
        side is a ``(0, n)`` matrix, so ``shape(side)[0]`` is its real row
        count.
    """
    # Split the matrix that was passed in (not a module-level global) and
    # keep every column rather than only the first two.
    above = asarray(dataSet[:, feat] > val).ravel()
    mat0 = dataSet[nonzero(above)[0], :]
    mat1 = dataSet[nonzero(~above)[0], :]
    return mat0, mat1


def createTree(dataSet, op=(1, 4)):
    """Recursively build a CART regression tree and return its root treeNode.

    Args:
        dataSet: ``numpy.matrix`` of samples whose last column is the target.
        op: ``(tolS, tolN)`` pre-pruning thresholds forwarded to
            chooseBestSplit; the default matches chooseBestSplit's own.

    Returns:
        A treeNode. Leaves have ``feat`` None and ``value`` set to the leaf
        value chosen by chooseBestSplit.
    """
    # chooseBestSplit computes the best split feature and threshold.
    feat, val = chooseBestSplit(dataSet, op)
    if feat is None:
        return treeNode(None, val)
    node = treeNode(feat, val)
    leftMat, rightMat = splitData(dataSet, feat, val)
    node.lc = createTree(leftMat, op)   # rows with feature > val
    node.rc = createTree(rightMat, op)  # rows with feature <= val
    return node


def calcValue(dataSet):
    """Return the leaf prediction: the mean of the target (last) column."""
    return mean(dataSet[:, -1])


def calcError(dataSet):
    """Return the total squared error of the target column about its mean.

    Computed as the (population) variance times the number of rows.
    """
    m = shape(dataSet)[0]
    return var(dataSet[:, -1]) * m
def chooseBestSplit(dataSet, op=(1, 4)):
    """Find the best binary split of ``dataSet`` (core of the algorithm).

    Every feature column (all but the last) and every distinct value in it
    is tried as a threshold; the split minimising the summed calcError of
    the two sides wins.

    Args:
        dataSet: ``numpy.matrix`` of samples whose last column is the target.
        op: ``(tolS, tolN)``. ``tolS`` is the minimum drop in total squared
            error a split must achieve; ``tolN`` is the minimum number of
            rows each side of a split must keep.

    Returns:
        ``(featIndex, splitValue)`` for the best split, with ``splitValue`` a
        Python float; or ``(None, calcValue(dataSet))`` when the node should
        be a leaf (targets all equal or no rows, no split satisfies
        ``tolN``, or the best split improves the error by less than ``tolS``).
    """
    tolS, tolN = op[0], op[1]
    n = shape(dataSet)[1]
    # Zero or one distinct target value: splitting cannot reduce the error.
    if len(set(dataSet[:, -1].T.tolist()[0])) <= 1:
        return None, calcValue(dataSet)
    S = calcError(dataSet)
    bestS = inf
    bestIndex = None
    bestValue = None
    for featIndex in range(n - 1):
        # Each distinct value once, in first-occurrence order, so ties still
        # go to the earliest row exactly as in a plain row-by-row scan.
        candidates = dict.fromkeys(dataSet[:, featIndex].T.tolist()[0])
        for splitVal in candidates:
            mat0, mat1 = splitData(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue  # pre-pruning: one side would be too small
            newS = calcError(mat0) + calcError(mat1)
            if newS < bestS:
                bestIndex, bestValue, bestS = featIndex, splitVal, newS
    # Only candidates that passed the tolN check were recorded above, so the
    # winning split does not need to be recomputed and re-checked here.
    if bestIndex is None or S - bestS < tolS:
        return None, calcValue(dataSet)
    return bestIndex, bestValue
# Demo script: build a tree from 'ex00.txt' and print it.
# NOTE(review): this runs at import time; consider moving it under an
# `if __name__ == "__main__":` guard if the module is ever imported.
data = loadData('ex00.txt')
node = createTree(data)


# Pre-order traversal of the tree.
def tree(node):
    """Print the tree rooted at ``node`` in pre-order (node, left, right).

    Each node is printed as ``feat : value``; leaf nodes created with
    ``feat=None`` print ``None`` as the feature.
    """
    if node is None:  # walked past a leaf: nothing to print
        return
    print(node.feat, ":", node.value)
    tree(node.lc)
    tree(node.rc)


tree(node)
# 原创粉丝点击