决策树建模

来源:互联网 发布:wamp改mysql密码 编辑:程序博客网 时间:2024/05/24 00:06

一、数据结构

#决策树节点(包括划分条件、值、结果及子分支)class decisionnode:    def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):        self.col=col#判别对象或者划分条件        self.value=value#划分条件的分界值        self.results=results#分类结果        self.tb=tb#满足划分条件的子分支        self.fb=fb#不满足划分条件的子分支

二、基础操作

#根据特征的值对集合进行划分def devideset(rows,column,value):    split_function=None    if isinstance(value,int) or isinstance(value,float):        split_function=lambda row:row[column]>=value#如果特征值是数,则以大小划分    else:        split_function=lambda row:row[column]==value#如果特征值不是数,则以是否相等划分    set1=[row for row in rows if split_function(row)]    set2=[row for row in rows if not split_function(row)]    return (set1,set2)#返回划分的两个集合#对记录作分类字典统计def uniquecounts(rows):    results={}    for row in rows:        r=row[len(row)-1]        if r not in results:            results[r]=0        results[r]+=1    return results#基尼不纯度#随机放置的数据项出现于错误分类中的概率def giniimpurity(rows):    total=len(rows)    counts=uniquecounts(rows)    imp=0    for k1 in counts:        p1=float(counts[k1])/total        for k2 in counts:            if k1==k2:                continue            p2=float(counts[k2])/total            imp+=p1*p2    return imp#数据集的熵值def entropy(rows):    log2=lambda x:log(x)/log(2)    results=uniquecounts(rows)    ent=0.0    for r in results.keys():        p=float(results[r])/len(rows)        ent=ent-p*log2(p)    return ent#针对类型为数的数据项计算方差值def variance(rows):    if len(rows)==0:        return 0    data=[float(row[len(row)-1]) for row in rows]    mean=sum(data)/len(data)    variance=sum([(d-mean)**2 for d in data])/len(data)    return variance

三、构建决策树

#构建决策树def buildtree(rows,scoref=entropy):    if len(rows)==0:        return decisionnode()    current_score=scoref(rows)#初始信息熵值    best_gain=0.0#最佳信息增益    best_criteria=None#最佳划分条件    best_sets=None#最佳划分集合    column_count=len(rows[0])-1#划分指标(特征)    for col in range(0,column_count):        column_values={}        for row in rows:            column_values[row[col]]=1#统计划分指标的可能取值        for value in column_values.keys():            (set1,set2)=devideset(rows,col,value)#对划分指标的特定取值进行划分            p=float(len(set1))/len(rows)            gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)#计算信息增益            if gain>best_gain and len(set1)>0 and len(set2)>0:                best_gain=gain#更新最佳信息增益                best_criteria=(col,value)#更新最佳划分条件                best_sets=(set1,set2)#更新最佳划分集合    if best_gain>0:        trueBranch=buildtree(best_sets[0])#递归构建符合划分条件的子分支        falseBranch=buildtree(best_sets[1])#递归构建不符合划分条件的子分支        #构建决策树内部节点        return decisionnode(col=best_criteria[0],value=best_criteria[1],tb=trueBranch,fb=falseBranch)    else:        return decisionnode(results=uniquecounts(rows))#构建决策树叶节点#递归打印决策树def printtree(tree,indent=''):    if tree.results!=None:        print(str(tree.results))#打印叶节点的分类值    else:        print(str(tree.col)+':'+str(tree.value)+'?')#打印内部节点的分类条件及判别值        print(indent+'T->',end='')#end=''打印不换行!!!        printtree(tree.tb,indent+'  ')#递归打印符合分类条件的分支        print(indent+'F->',end='')        printtree(tree.fb,indent+'  ')#递归打印不符合分类条件的分支

四、预测

#决策树预测def classify(observation,tree):    if tree.results!=None:        return tree.results#到达叶节点返回分类值    else:        v=observation[tree.col]#对象在划分条件下的实际取值        branch=None        #按照划分条件及值类型选择下个分支        if isinstance(v,int) or isinstance(v,float):            if v>=tree.value:                branch=tree.tb            else:                branch=tree.fb        else:            if v==tree.value:                branch=tree.tb            else:                branch=tree.fb        return classify(observation,branch)#递归分类向下走#增加了对数据缺失情况的处理def mdclassfy(observation,tree):    if tree.results!=None:        return tree.results    else:        v=observation[tree.col]        if v==None:            #如果数据不存在,则两个分支都有可能,因而顺着两条分支分别走下去,然后回来合并两种情况            tr,fr=mdclassfy(observation,tree.tb),mdclassfy(observation,tree.fb)            tcount=sum(tr.values())            fcount=sum(fr.values())            tw=float(tcount)/(tcount+fcount)#根据两条分支下的结果数目设置影响结果的权值            fw=float(fcount)/(tcount+fcount)            result={}            #接下来对两个分支的预测结果合并到同一个result里面去            for k,v in tr.items():                result[k]=tw*v            for k,v in fr.items():                if k not in result:                    result[k]=0                result[k]+=v*fw            return result#返回的不是具体某个分类,而是多个具有不同可能性的分类结果集合        else:            if isinstance(v,int) or isinstance(v,float):                if v>=tree.value:                    branch=tree.tb                else:                    branch=tree.fb            else:                if v==tree.value:                    branch=tree.tb                else:                    branch=tree.fb            return mdclassfy(observation,branch)

五、剪枝

#为避免过拟合对决策树进行剪枝处理def prune(tree,mingain):    if tree.tb.results==None:        prune(tree.tb,mingain)    if tree.fb.results==None:        prune(tree.fb,mingain)    #从最底层的叶节点开始进行剪枝    if tree.tb.results!=None and tree.fb.results!=None:        tb,fb=[],[]        #从结果复原数据        for v,c in tree.tb.results.items():            tb+=[[v]]*c        for v,c in tree.fb.results.items():            fb+=[[v]]*c        delta=entropy(tb+fb)-(entropy(tb)+entropy(fb)/2)#剪枝合并前后的熵值变化量        if delta<mingain:#如果剪枝合并后增加的熵值小于设定的阈值,则进行剪枝处理            tree.tb,tree.fb=None,None            tree.results=uniquecounts(tb+fb)

六、测试

my_data=[['slashdot','USA','yes',18,'None'],        ['google','France','yes',23,'Premium'],        ['digg','USA','yes',24,'Basic'],        ['kiwitobes','France','yes',23,'Basic'],        ['google','UK','no',21,'Premium'],        ['(direct)','New Zealand','no',12,'None'],        ['(direct)','UK','no',21,'Basic'],        ['google','USA','no',24,'Premium'],        ['slashdot','France','yes',19,'None'],        ['digg','USA','no',18,'None'],        ['google','UK','no',18,'None'],        ['kiwitobes','UK','no',19,'None'],        ['digg','New Zealand','yes',12,'Basic'],        ['slashdot','UK','no',21,'None'],        ['google','UK','yes',18,'Basic'],        ['kiwitobes','France','yes',19,'Basic']]tree=buildtree(my_data)printtree(tree)print(classify([' (direct)','USA','yes',5],tree))print(mdclassfy(['google',None,'yes',None],tree))print(mdclassfy(['google','France',None,None],tree))print(' ')prune(tree,0.1)printtree(tree)print(' ')prune(tree,1)printtree(tree)print(classify([' (direct)','USA','yes',5],tree))print(mdclassfy(['google',None,'yes',None],tree))print(mdclassfy(['google','France',None,None],tree))

输出

0:google?T->3:21?  T->{'Premium': 3}  F->2:yes?    T->{'Basic': 1}    F->{'None': 1}F->0:slashdot?  T->{'None': 3}  F->2:yes?    T->{'Basic': 4}    F->3:21?      T->{'Basic': 1}      F->{'None': 3}{'Basic': 4}{'Premium': 2.25, 'Basic': 0.25}{'None': 0.125, 'Premium': 2.25, 'Basic': 0.125}0:google?T->3:21?  T->{'Premium': 3}  F->2:yes?    T->{'Basic': 1}    F->{'None': 1}F->0:slashdot?  T->{'None': 3}  F->2:yes?    T->{'Basic': 4}    F->3:21?      T->{'Basic': 1}      F->{'None': 3}0:google?T->3:21?  T->{'Premium': 3}  F->2:yes?    T->{'Basic': 1}    F->{'None': 1}F->{'None': 6, 'Basic': 5}{'None': 6, 'Basic': 5}{'Premium': 2.25, 'Basic': 0.25}{'None': 0.125, 'Premium': 2.25, 'Basic': 0.125}
0 0
原创粉丝点击