Decision Tree Implementations in Python


Python Implementation 1

Here we first call the interface in the sklearn package to see how the algorithm performs.

Experimental data (it may look familiar): each line gives a height in metres, a weight in kilograms, and a thin/fat label.

1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat
1.9 92 fat
1.6 50 fat
1.65 50 thin
1.68 48 thin
1.7 60 thin
1.7 50 thin
1.7 65 fat

Calling the algorithm:

# -*- coding: utf-8 -*-
import numpy as np
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split  # formerly sklearn.cross_validation

##### Load the data #####
data = []
labels = []
with open('./data/1.txt') as f:
    for line in f:
        tokens = line.strip().split(' ')
        data.append([float(tk) for tk in tokens[:-1]])
        labels.append(tokens[-1])
data = np.array(data)
labels = np.array(labels)

##### Convert the labels to 0/1 #####
label = np.zeros(labels.shape)
label[labels == 'fat'] = 1

##### Split into a training set and a test set #####
data_train, data_test, label_train, label_test = train_test_split(
    data, label, test_size=0.2, random_state=0)

##### Train the decision tree, using information entropy as the split criterion #####
clf = tree.DecisionTreeClassifier(criterion='entropy')
'''
criterion : string, optional (default="gini")
    The function to measure the quality of a split. Supported criteria are
    "gini" for the Gini impurity and "entropy" for the information gain.
'''
print(clf)
clf.fit(data_train, label_train)

##### Write the tree structure to a file #####
with open('tree.dot', 'w') as f:
    tree.export_graphviz(clf, out_file=f)

##### Feature importances: the larger a value, the more that feature contributes to the classification #####
print(clf.feature_importances_)

##### Print the test results #####
answer = clf.predict(data_test)
print(data_test)
print(answer)
print(label_test)
print(np.mean(answer == label_test))

##### Precision and recall #####
precision, recall, threshold = precision_recall_curve(label_test, clf.predict(data_test))

answer = clf.predict(data)  # predict on the whole data set
print(label)
print(answer)
print(classification_report(label, answer, target_names=['thin', 'fat']))
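Once trained, the classifier can be applied to samples that were not in the data file. A minimal sketch (the height/weight values here are made up for illustration):

# Hypothetical new samples in [height, weight] form; the values are invented for illustration
new_samples = [[1.75, 70], [1.60, 45]]
print(clf.predict(new_samples))        # predicted labels: 1 = fat, 0 = thin
print(clf.predict_proba(new_samples))  # per-class probability estimates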

Python Implementation 2

The following is a fairly simple ID3 algorithm implemented in Python; the overall flow is easy to follow. The code does not handle missing labels or attribute values; its only purpose is to make the basic idea of the algorithm easier to understand.
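At each node, ID3 splits on the attribute with the largest information gain. For the five-sample data set used below (two 'yes', three 'no'), the numbers can be checked by hand; here is a minimal sketch of that calculation (variable names are just for illustration):

import math

# Base entropy of the toy data set below: 2 'yes' and 3 'no' samples
base = -(2/5) * math.log(2/5, 2) - (3/5) * math.log(3/5, 2)   # ~0.971

# Split on feature A: A=1 -> [yes, yes, no], A=0 -> [no, no]
h_A1 = -(2/3) * math.log(2/3, 2) - (1/3) * math.log(1/3, 2)   # ~0.918
gain_A = base - (3/5) * h_A1 - (2/5) * 0.0                    # ~0.420

# Split on feature B: B=1 -> [yes, yes, no, no], B=0 -> [no]
gain_B = base - (4/5) * 1.0 - (1/5) * 0.0                     # ~0.171

print(gain_A, gain_B)  # A has the larger gain, so ID3 splits on A first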

# -*- coding: utf-8 -*-
import operator
import math

# Build the data set
def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    features = ['A', 'B']
    return dataSet, features

# Grow the decision tree recursively
def treeGrowth(dataSet, features):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:  # no features left
        return classify(classList)
    bestFeat = findBestSplit(dataSet)  # bestFeat is the index of the best feature
    bestFeatLabel = features[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]  # values taken by the chosen feature
    uniqueFeatValues = set(featValues)
    del(features[bestFeat])  # remove the chosen feature
    for values in uniqueFeatValues:
        subDataSet = splitDataSet(dataSet, bestFeat, values)
        # pass a copy so that sibling branches do not share one mutated feature list
        myTree[bestFeatLabel][values] = treeGrowth(subDataSet, features[:])
    return myTree

# Majority vote, used when no features remain but the samples still belong to several classes
def classify(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Find the best attribute to split on (try every attribute, compute its
# information gain, and take the one with the largest gain)
def findBestSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeat = -1
    for i in range(numFeatures):
        featValues = [example[i] for example in dataSet]
        uniqueFeatValues = set(featValues)
        newEntropy = 0.0
        for val in uniqueFeatValues:
            subDataSet = splitDataSet(dataSet, i, val)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        if (baseEntropy - newEntropy) > bestInfoGain:
            bestInfoGain = baseEntropy - newEntropy
            bestFeat = i
    return bestFeat

# Split the data set on the chosen attribute
def splitDataSet(dataSet, feat, values):
    retDataSet = []
    for featVec in dataSet:
        if featVec[feat] == values:
            reduceFeatVec = featVec[:feat]
            reduceFeatVec.extend(featVec[feat + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet

# Compute the entropy of the data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        if prob != 0:
            shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt

# Classify a new example with the tree built above
def predict(tree, newObject):
    while isinstance(tree, dict):
        key = next(iter(tree))
        tree = tree[key][newObject[key]]
    return tree

if __name__ == '__main__':
    dataSet, features = createDataSet()
    tree = treeGrowth(dataSet, features)
    print(tree)
    print(predict(tree, {'A': 1, 'B': 1}))

The result:
The final decision tree is: {'A': {0: 'no', 1: {'B': {0: 'no', 1: 'yes'}}}}
The prediction is: yes

Reading the nested dict for {'A': 1, 'B': 1}: the root tests A, whose value 1 leads to the subtree on B, and B = 1 leads to the leaf 'yes'.

Python Implementation 3

The following is a Python implementation of a binary decision tree (every split has exactly two branches). You can choose either information entropy or the Gini index as the criterion for building the tree, and visualization of the tree is implemented as well.
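Both criteria score how mixed a node is, in terms of the class proportions p_i at that node. For reference, the two measures as implemented below are:

    entropy:        H = -Σ_i p_i · log2(p_i)
    Gini impurity:  G = Σ_{i≠j} p_i · p_j = 1 - Σ_i p_i²

A split is chosen to maximize the weighted reduction of whichever measure is used, which is exactly the gain computation in BuildTree.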

# -*- coding: utf-8 -*-
from math import log
from PIL import Image, ImageDraw

my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['slashdot', 'UK', 'no', 21, 'None'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]

# A decision node
class decidenode:
    def __init__(self, col=-1, value=None, result=None, tb=None, fb=None):
        self.col = col        # column index of the attribute tested at this node
        self.value = value    # the value the column must match for the answer to be true
        self.result = result  # class counts stored at a leaf node
        self.tb = tb          # subtree for the true branch
        self.fb = fb          # subtree for the false branch

# Split the rows on a column, handling numeric and categorical values
def DivideSet(rows, column, value):
    if isinstance(value, int) or isinstance(value, float):
        splitfunction = lambda x: x >= value
    else:
        splitfunction = lambda x: x == value
    set1 = [row for row in rows if splitfunction(row[column])]
    set2 = [row for row in rows if not splitfunction(row[column])]
    return (set1, set2)

# Count how many rows fall into each class
def UniqueCount(rows):
    result = {}
    for row in rows:
        r = row[-1]
        result.setdefault(r, 0)
        result[r] += 1
    return result

# Gini impurity
def GiniImpurity(rows):
    total = len(rows)
    counts = UniqueCount(rows)
    imp = 0
    for k1 in counts:
        p1 = float(counts[k1]) / total
        for k2 in counts:
            if k1 == k2:
                continue
            p2 = float(counts[k2]) / total
            imp += p1 * p2
    return imp

# Information entropy
def entropy(rows):
    log2 = lambda x: log(x) / log(2)
    results = UniqueCount(rows)
    ent = 0.0
    for r in results.keys():
        p = float(results[r]) / len(rows)
        ent = ent - p * log2(p)
    return ent

# Variance: when the output is continuous, use variance to judge a split
# (the two branches then hold the larger and the smaller values);
# leaf nodes can later be merged by post-pruning
def variance(rows):
    if len(rows) == 0:
        return 0
    data = [row[-1] for row in rows]
    mean = sum(data) / len(data)
    return sum([(d - mean) ** 2 for d in data]) / len(data)

# Build the decision tree recursively
def BuildTree(rows, judge=entropy):
    if len(rows) == 0:
        return decidenode()
    best_gain = 0
    best_value = None
    best_sets = None
    best_col = None
    S = judge(rows)
    # find the split with the largest gain
    for col in range(len(rows[0]) - 1):
        total_value = {}
        for row in rows:
            total_value[row[col]] = 1
        for value in total_value.keys():
            (set1, set2) = DivideSet(rows, col, value)
            s1 = float(len(set1)) / len(rows)
            s2 = float(len(set2)) / len(rows)
            gain = S - s1 * judge(set1) - s2 * judge(set2)
            if gain > best_gain:
                best_gain = gain
                best_value = value
                best_col = col
                best_sets = (set1, set2)
    # create the node; pass judge down so the whole tree uses one criterion
    if best_gain > 0:
        truebranch = BuildTree(best_sets[0], judge)
        falsebranch = BuildTree(best_sets[1], judge)
        return decidenode(col=best_col, value=best_value, tb=truebranch, fb=falsebranch)
    else:
        return decidenode(result=UniqueCount(rows))

# Print the tree as text
def PrintTree(tree, indent=''):
    if tree.result is not None:
        print(str(tree.result))
    else:
        print('%s:%s?' % (tree.col, tree.value))
        print(indent, 'T->', end=' ')
        PrintTree(tree.tb, indent + '  ')
        print(indent, 'F->', end=' ')
        PrintTree(tree.fb, indent + '  ')

def getwidth(tree):
    if tree.tb is None and tree.fb is None:
        return 1
    return getwidth(tree.tb) + getwidth(tree.fb)

def getdepth(tree):
    if tree.tb is None and tree.fb is None:
        return 0
    return max(getdepth(tree.tb), getdepth(tree.fb)) + 1

# Draw the tree as an image
def drawtree(tree, jpeg='tree.jpg'):
    w = getwidth(tree) * 100
    h = getdepth(tree) * 100 + 120
    img = Image.new('RGB', (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    drawnode(draw, tree, w / 2, 20)
    img.save(jpeg, 'JPEG')

def drawnode(draw, tree, x, y):
    if tree.result is None:
        # width of each branch
        w1 = getwidth(tree.fb) * 100
        w2 = getwidth(tree.tb) * 100
        # total horizontal space required by this node
        left = x - (w1 + w2) / 2
        right = x + (w1 + w2) / 2
        # draw the condition string
        draw.text((x - 20, y - 10), str(tree.col) + ':' + str(tree.value), (0, 0, 0))
        # draw links to the branches
        draw.line((x, y, left + w1 / 2, y + 100), fill=(255, 0, 0))
        draw.line((x, y, right - w2 / 2, y + 100), fill=(255, 0, 0))
        # draw the branch nodes
        drawnode(draw, tree.fb, left + w1 / 2, y + 100)
        drawnode(draw, tree.tb, right - w2 / 2, y + 100)
    else:
        txt = ' \n'.join(['%s:%d' % v for v in tree.result.items()])
        draw.text((x - 20, y), txt, (0, 0, 0))

# Classify a new observation
def classify(observation, tree):
    if tree.result is not None:
        return tree.result
    v = observation[tree.col]
    if isinstance(v, int) or isinstance(v, float):
        branch = tree.tb if v >= tree.value else tree.fb
    else:
        branch = tree.tb if v == tree.value else tree.fb
    return classify(observation, branch)

# Post-pruning with a threshold mingain: merge two leaves whenever merging
# them increases the entropy by less than the threshold
def prune(tree, mingain):
    if tree.tb.result is None:
        prune(tree.tb, mingain)
    if tree.fb.result is None:
        prune(tree.fb, mingain)
    if tree.tb.result is not None and tree.fb.result is not None:
        tb1, fb1 = [], []
        # rebuild rows in the [value] format that UniqueCount and entropy expect
        for v, c in tree.tb.result.items():
            tb1 += [[v]] * c
        for v, c in tree.fb.result.items():
            fb1 += [[v]] * c
        delta = entropy(tb1 + fb1) - (entropy(tb1) + entropy(fb1)) / 2
        if delta < mingain:
            tree.tb, tree.fb = None, None
            tree.result = UniqueCount(tb1 + fb1)

# Classify an observation that has missing attribute values
def mdclassify(observation, tree):
    if tree.result is not None:
        return tree.result
    if observation[tree.col] is None:
        # follow both branches and weight each result by the size of its branch
        tb, fb = mdclassify(observation, tree.tb), mdclassify(observation, tree.fb)
        tbcount = sum(tb.values())
        fbcount = sum(fb.values())
        tw = float(tbcount) / (tbcount + fbcount)
        fw = float(fbcount) / (tbcount + fbcount)
        result = {}
        for k, v in tb.items():
            result.setdefault(k, 0)
            result[k] = v * tw
        for k, v in fb.items():
            result.setdefault(k, 0)
            result[k] = v * fw
        return result
    else:
        v = observation[tree.col]
        if isinstance(v, int) or isinstance(v, float):
            branch = tree.tb if v >= tree.value else tree.fb
        else:
            branch = tree.tb if v == tree.value else tree.fb
        return mdclassify(observation, branch)

def main():
    # code used to test the decision tree
    a = BuildTree(my_data)
    print("The generated decision tree is:\n")
    PrintTree(a)
    print('----' * 15)
    drawtree(a, jpeg='treeview.jpg')
    print("After pruning with threshold 0.1:\n")
    prune(a, 0.1)
    PrintTree(a)
    print('-----' * 15)
    prune(a, 1)
    print("After pruning with threshold 1:\n")
    PrintTree(a)
    print('----' * 15)
    print(mdclassify(['google', 'France', None, None], a))
    print(mdclassify(['google', None, 'yes', None], a))

if __name__ == '__main__':
    main()
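Since BuildTree takes the impurity function as its judge argument, switching from entropy to the Gini index is a one-line change. A minimal usage sketch (the output file name is arbitrary):

gini_tree = BuildTree(my_data, judge=GiniImpurity)  # build with Gini instead of entropy
PrintTree(gini_tree)
drawtree(gini_tree, jpeg='treeview_gini.jpg')       # file name chosen for this example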

The output is:

[Figure: console output of main() and the tree image saved as treeview.jpg]

(End)

The so-called extraordinary is just the ordinary raised to the Nth power.                        -------By Ada