决策树Python实现
来源:互联网 发布:苹果网络电话软件 编辑:程序博客网 时间:2024/06/06 10:39
Python实现一
在这里我们先调用sklearn算法包中的接口,看一下算法的效果。
实验数据(可能你并不陌生~~~):
1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat
1.9 92 fat
1.6 50 fat
1.65 50 thin
1.68 48 thin
1.7 60 thin
1.7 50 thin
1.7 65 fat
算法调用:
# -*- coding: utf-8 -*-
"""Decision-tree demo using scikit-learn.

Reads whitespace-separated records ``height weight label`` from
``./data/1.txt``, trains an entropy-based DecisionTreeClassifier, exports the
tree structure to Graphviz .dot format, and prints accuracy plus
precision/recall statistics.
"""
import numpy as np
import scipy as sp  # kept from the original; unused here
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
# sklearn.cross_validation was removed in scikit-learn 0.20;
# train_test_split now lives in sklearn.model_selection.
from sklearn.model_selection import train_test_split


def main():
    # ---- load data: all columns but the last are float features, the last
    # column is the 'thin'/'fat' class label ----
    data = []
    labels = []
    with open('./data/1.txt') as src:  # context manager: file is always closed
        for line in src:
            tokens = line.strip().split(' ')
            data.append([float(tk) for tk in tokens[:-1]])
            labels.append(tokens[-1])
    data = np.array(data)
    labels = np.array(labels)

    # ---- encode labels as 0/1 (1 == 'fat') ----
    label = np.zeros(labels.shape)
    label[labels == 'fat'] = 1

    # ---- split into training and test sets (20% held out) ----
    data_train, data_test, label_train, label_test = train_test_split(
        data, label, test_size=0.2, random_state=0)

    # ---- train a decision tree using information gain as split criterion ----
    # criterion: "gini" (default) for Gini impurity, "entropy" for
    # information gain.
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    print(clf)
    clf.fit(data_train, label_train)

    # ---- dump the tree structure to a Graphviz .dot file ----
    # NOTE: the original rebound the data-file handle `f` here and never
    # closed it; use a separate, managed handle instead.
    with open('tree.dot', 'w') as dot_file:
        tree.export_graphviz(clf, out_file=dot_file)

    # feature_importances_: larger value == bigger influence on the splits
    print(clf.feature_importances_)

    # ---- evaluate on the held-out test set ----
    answer = clf.predict(data_test)
    print(data_test)
    print(answer)
    print(label_test)
    print(np.mean(answer == label_test))

    # ---- precision / recall on the test set ----
    precision, recall, threshold = precision_recall_curve(
        label_test, clf.predict(data_test))

    # ---- predict on the full data set and print a report ----
    answer = clf.predict(data)
    print(label)
    print(answer)
    print(classification_report(label, answer, target_names=['thin', 'fat']))


if __name__ == '__main__':
    main()
Python 实现二
以下使用Python实现的比较简单的ID3算法,整个过程还是比较好理解的,在该代码中没有考虑标签和属性值缺失的情况,目的只是为了大家更好的理解算法的基本思想。
# -*- coding: utf-8 -*-
"""A minimal ID3 decision tree.

Missing labels/attribute values are not handled; the goal is only to
illustrate the basic algorithm.
"""
import math
import operator


def createDataSet():
    """Build the toy training set: two binary features A, B plus a class label."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    features = ['A', 'B']
    return dataSet, features


def treeGrowth(dataSet, features):
    """Recursively grow an ID3 tree.

    Returns a class label (leaf) or a nested dict
    {feature_name: {value: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return that label as a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left (only the label column remains): majority vote.
    if len(dataSet[0]) == 1:
        return classify(classList)
    bestFeat = findBestSplit(dataSet)  # index of the best feature
    bestFeatLabel = features[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueFeatValues = set(featValues)
    del features[bestFeat]  # this feature is consumed at this node
    for value in uniqueFeatValues:
        subDataSet = splitDataSet(dataSet, bestFeat, value)
        # Bug fix: recurse on a copy of the remaining features, so that
        # deletions made while growing one branch cannot corrupt the
        # feature list seen by its sibling branches.
        myTree[bestFeatLabel][value] = treeGrowth(subDataSet, features[:])
    return myTree


def classify(classList):
    """Majority vote: most frequent label when no features remain."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # dict.iteritems() no longer exists in Python 3; use items().
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def findBestSplit(dataSet):
    """Return the index of the feature with the largest information gain."""
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeat = -1
    for i in range(numFeatures):
        featValues = [example[i] for example in dataSet]
        uniqueFeatValues = set(featValues)
        newEntropy = 0.0
        for val in uniqueFeatValues:
            subDataSet = splitDataSet(dataSet, i, val)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        if (baseEntropy - newEntropy) > bestInfoGain:
            bestInfoGain = baseEntropy - newEntropy
            bestFeat = i
    return bestFeat


def splitDataSet(dataSet, feat, values):
    """Rows whose column `feat` equals `values`, with that column removed."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[feat] == values:
            reducedFeatVec = featVec[:feat] + featVec[feat + 1:]
            retDataSet.append(reducedFeatVec)
    return retDataSet


def calcShannonEnt(dataSet):
    """Shannon entropy (base 2) of the class-label column (last column)."""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        if prob != 0:
            shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt


def predict(tree, newObject):
    """Classify `newObject` ({feature_name: value}) by walking the tree."""
    while isinstance(tree, dict):
        # dict.keys() is not indexable in Python 3; each internal node has
        # exactly one root key (the feature tested at that node).
        key = next(iter(tree))
        tree = tree[key][newObject[key]]
    return tree


if __name__ == '__main__':
    dataSet, features = createDataSet()
    tree = treeGrowth(dataSet, features)
    print(tree)
    print(predict(tree, {'A': 1, 'B': 1}))
实验结果为:
最终构造的决策树为:{'A': {0: 'no', 1: {'B': {0: 'no', 1: 'yes'}}}}
测试结果为:yes
Python 实现三
以下是用Python实现的二叉决策树(分支只能是两个)方法,在该算法中你可以选择使用信息熵或者选择使用基尼指数来建立决策树,并且实现了可视化。
# -*- coding: utf-8 -*-
"""A binary decision tree (every split has exactly two branches).

Either information entropy or the Gini index can be chosen as the split
criterion; the tree can be printed as text or drawn as an image
(drawing requires the optional PIL/Pillow package, imported lazily).
"""
from math import log
import zlib  # kept from the original; unused here

# Sample data: [referrer, country, read_FAQ, pages_viewed, service_chosen]
my_data = [['slashdot', 'USA', 'yes', 18, 'None'],
           ['google', 'France', 'yes', 23, 'Premium'],
           ['digg', 'USA', 'yes', 24, 'Basic'],
           ['kiwitobes', 'France', 'yes', 23, 'Basic'],
           ['google', 'UK', 'no', 21, 'Premium'],
           ['(direct)', 'New Zealand', 'no', 12, 'None'],
           ['(direct)', 'UK', 'no', 21, 'Basic'],
           ['google', 'USA', 'no', 24, 'Premium'],
           ['slashdot', 'France', 'yes', 19, 'None'],
           ['digg', 'USA', 'no', 18, 'None'],
           ['google', 'UK', 'no', 18, 'None'],
           ['kiwitobes', 'UK', 'no', 19, 'None'],
           ['digg', 'New Zealand', 'yes', 12, 'Basic'],
           ['slashdot', 'UK', 'no', 21, 'None'],
           ['google', 'UK', 'yes', 18, 'Basic'],
           ['kiwitobes', 'France', 'yes', 19, 'Basic']]


class decidenode():
    """One node of the decision tree."""

    def __init__(self, col=-1, value=None, result=None, tb=None, fb=None):
        self.col = col        # column index of the attribute tested at this node
        self.value = value    # value the column must match for the branch to be true
        self.result = result  # leaf nodes: {class: count}; None for internal nodes
        self.tb = tb          # subtree followed when the test is true
        self.fb = fb          # subtree followed when the test is false


def DivideSet(rows, column, value):
    """Split rows into (matching, non-matching) sets on column/value.

    Numeric values split on ``>= value``; everything else on equality.
    """
    if isinstance(value, int) or isinstance(value, float):
        splitfunction = lambda x: x >= value
    else:
        splitfunction = lambda x: x == value
    set1 = [row for row in rows if splitfunction(row[column])]
    set2 = [row for row in rows if not splitfunction(row[column])]
    return (set1, set2)


def UniqueCount(rows):
    """Count occurrences of each class label (last column): {class: count}."""
    result = {}
    for row in rows:
        r = row[len(row) - 1]
        result.setdefault(r, 0)
        result[r] += 1
    return result


def GiniImpurity(rows):
    """Gini impurity: probability that two randomly drawn rows differ in class."""
    total = len(rows)
    counts = UniqueCount(rows)
    imp = 0
    for k1 in counts:
        p1 = float(counts[k1]) / total
        for k2 in counts:
            if k1 == k2:
                continue
            p2 = float(counts[k2]) / total
            imp += p1 * p2
    return imp


def entropy(rows):
    """Information entropy (base 2) of the class-label column."""
    log2 = lambda x: log(x) / log(2)
    results = UniqueCount(rows)
    ent = 0.0
    for r in results.keys():
        p = float(results[r]) / len(rows)
        ent = ent - p * log2(p)
    return ent


def variance(rows):
    """Variance of the outcome column, used as impurity for numeric targets.

    Bug fix: the original indexed ``row[len(rows)-1]`` — the number of rows
    instead of the row length — so it read the wrong column whenever the
    row count differed from the row width. Use ``row[-1]``.
    """
    if len(rows) == 0:
        return 0
    data = [row[-1] for row in rows]
    mean = sum(data) / len(data)
    return sum((d - mean) ** 2 for d in data) / len(data)


def BuildTree(rows, judge=entropy):
    """Recursively build the tree.

    At each node, pick the (column, value) split with the largest reduction
    of the impurity measure `judge`; stop when no split improves it.
    """
    if len(rows) == 0:
        return decidenode()
    best_gain = 0
    best_value = None
    best_sets = None
    best_col = None
    S = judge(rows)
    # Try every value of every non-label column as a split candidate.
    for col in range(len(rows[0]) - 1):
        total_value = {}
        for row in rows:
            total_value[row[col]] = 1
        for value in total_value.keys():
            (set1, set2) = DivideSet(rows, col, value)
            # Weighted impurity of the two halves; keep the best gain.
            s1 = float(len(set1)) / len(rows)
            s2 = float(len(set2)) / len(rows)
            gain = S - s1 * judge(set1) - s2 * judge(set2)
            if gain > best_gain:
                best_gain = gain
                best_value = value
                best_col = col
                best_sets = (set1, set2)
    if best_gain > 0:
        # Bug fix: propagate `judge` to the recursive calls — the original
        # always fell back to entropy for the subtrees even when the caller
        # requested a different criterion (e.g. GiniImpurity).
        truebranch = BuildTree(best_sets[0], judge)
        falsebranch = BuildTree(best_sets[1], judge)
        return decidenode(col=best_col, value=best_value,
                          tb=truebranch, fb=falsebranch)
    else:
        return decidenode(result=UniqueCount(rows))


def PrintTree(tree, indent=''):
    """Print the tree as indented text (leaves show {class: count})."""
    if tree.result != None:
        print(str(tree.result))
    else:
        print('%s:%s?' % (tree.col, tree.value))
        print(indent, 'T->', end=' ')
        PrintTree(tree.tb, indent + '  ')
        print(indent, 'F->', end=' ')
        PrintTree(tree.fb, indent + '  ')


def getwidth(tree):
    """Number of leaves below `tree` (used to size the drawing)."""
    if tree.tb == None and tree.fb == None:
        return 1
    return getwidth(tree.tb) + getwidth(tree.fb)


def getdepth(tree):
    """Depth of `tree` in edges (used to size the drawing)."""
    if tree.tb == None and tree.fb == None:
        return 0
    return max(getdepth(tree.tb), getdepth(tree.fb)) + 1


def drawtree(tree, jpeg='tree.jpg'):
    """Render the tree to a JPEG file. Requires the optional PIL/Pillow
    package, imported lazily so the rest of the module works without it."""
    from PIL import Image, ImageDraw
    w = getwidth(tree) * 100
    h = getdepth(tree) * 100 + 120
    img = Image.new('RGB', (w, h), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    drawnode(draw, tree, w / 2, 20)
    img.save(jpeg, 'JPEG')


def drawnode(draw, tree, x, y):
    """Draw one node (and, recursively, its branches) at position (x, y)."""
    if tree.result == None:
        # Width of each branch determines horizontal placement.
        w1 = getwidth(tree.fb) * 100
        w2 = getwidth(tree.tb) * 100
        left = x - (w1 + w2) / 2
        right = x + (w1 + w2) / 2
        # Condition string, then links to the two branches.
        draw.text((x - 20, y - 10), str(tree.col) + ':' + str(tree.value), (0, 0, 0))
        draw.line((x, y, left + w1 / 2, y + 100), fill=(255, 0, 0))
        draw.line((x, y, right - w2 / 2, y + 100), fill=(255, 0, 0))
        drawnode(draw, tree.fb, left + w1 / 2, y + 100)
        drawnode(draw, tree.tb, right - w2 / 2, y + 100)
    else:
        txt = ' \n'.join(['%s:%d' % v for v in tree.result.items()])
        draw.text((x - 20, y), txt, (0, 0, 0))


def classify(observation, tree):
    """Classify a complete observation; returns the leaf's {class: count}."""
    if tree.result != None:
        return tree.result
    else:
        v = observation[tree.col]
        branch = None
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.tb
            else:
                branch = tree.fb
        else:
            if v == tree.value:
                branch = tree.tb
            else:
                branch = tree.fb
        return classify(observation, branch)


def prune(tree, mingain):
    """Post-prune: merge a pair of leaves when the entropy increase from
    merging them is below `mingain`."""
    if tree.tb.result == None:
        prune(tree.tb, mingain)
    if tree.fb.result == None:
        prune(tree.fb, mingain)
    if tree.tb.result != None and tree.fb.result != None:
        # Rebuild row lists from the leaf counts so UniqueCount/entropy can
        # operate on them in their usual format.
        tb1, fb1 = [], []
        for v, c in tree.tb.result.items():
            tb1 += [[v]] * c
        for v, c in tree.fb.result.items():
            fb1 += [[v]] * c
        # Bug fix: the mean of the two branch entropies is
        # (entropy(tb1) + entropy(fb1)) / 2; the original only divided
        # entropy(fb1) by 2 due to missing parentheses.
        delta = entropy(tb1 + fb1) - (entropy(tb1) + entropy(fb1)) / 2
        if delta < mingain:
            tree.tb, tree.fb = None, None
            tree.result = UniqueCount(tb1 + fb1)


def mdclassify(observation, tree):
    """Classify an observation that may have missing (None) attributes.

    A missing attribute follows both branches; the branch results are
    combined, weighted by the fraction of training rows in each branch.
    """
    if tree.result != None:
        return tree.result
    if observation[tree.col] == None:
        # tb/fb are the {class: count} dicts returned by the recursion.
        tb, fb = mdclassify(observation, tree.tb), mdclassify(observation, tree.fb)
        tbcount = sum(tb.values())
        fbcount = sum(fb.values())
        tw = float(tbcount) / (tbcount + fbcount)
        fw = float(fbcount) / (tbcount + fbcount)
        result = {}
        for k, v in tb.items():
            result.setdefault(k, 0)
            result[k] = v * tw
        for k, v in fb.items():
            result.setdefault(k, 0)
            result[k] = v * fw
        return result
    else:
        v = observation[tree.col]
        branch = None
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.tb
            else:
                branch = tree.fb
        else:
            if v == tree.value:
                branch = tree.tb
            else:
                branch = tree.fb
        return mdclassify(observation, branch)


def main():
    # Demo: build, draw, prune and query the tree on the sample data.
    a = BuildTree(my_data)
    print("生成的决策树为:\n")
    PrintTree(a)
    print('----' * 15)
    drawtree(a, jpeg='treeview.jpg')
    print("利用阈值0.1进行剪枝后,得到:\n")
    prune(a, 0.1)
    PrintTree(a)
    print('-----' * 15)
    prune(a, 1)
    print("利用阈值1进行剪枝后,得到:\n")
    PrintTree(a)
    print('----' * 15)
    print(mdclassify(['google', 'France', None, None], a))
    print(mdclassify(['google', None, 'yes', None], a))


if __name__ == '__main__':
    main()
实验结果为:
《完》
所谓的不平凡就是平凡的N次幂。 -------By Ada
阅读全文
0 0
- # 详解决策树、python实现决策树
- Python实现决策树算法
- 决策树--Python实现
- 决策树及其python实现
- 决策树原理-python实现
- Python实现决策树
- Python分类决策树实现
- 决策树的Python实现
- 决策树及其Python实现
- 决策树的python实现
- python实现决策树
- 决策树(Python实现)
- 决策树的python实现
- python实现决策树分类
- 决策树Python实现
- 决策树算法--python实现
- 决策树Python实现
- python 实现决策树画图
- selenium3.4 定位动态的iframe
- 云笔记项目 Unit03
- j2se项目如何打成可以运行Jar包
- HTML5 拖拽效果 解析
- 使用Eclipse构建Maven项目 (step-by-step)
- 决策树Python实现
- 模板的多态
- 华为机试:求最小公倍数、Ja题目2-3级(走格子)
- linux不能import caffe
- Error:Conflict with dependency 'com.google.code.findbugs:jsr305' in project ':app':报错解决
- office2007安装找不到文件问题
- 对于 升级 Xcode 9.0 beta2 产生的各种问题解决办法。
- .net 获取网站根目录的方法
- sqlserver 2013 sa用户添加sql数据库的映射 无法使用特殊主体sa