id3算法(python代码)

来源:互联网 发布:mysql 创建覆盖索引 编辑:程序博客网 时间:2024/06/08 09:33

1. 该python实现没有考虑到overfitting。

[python] view plaincopy
  1. # coding=utf-8  
  2. from numpy import *  
  3. from math import log  
  4. #下面的函数用来计算香农熵 H=sum(-p(xi)log(p(xi)))  其中xi指的是每种类别所占的比例  
  5. def calcShannonEnt(dataSet):  
  6.     numEntries=len(dataSet)#数据集的行数  
  7.     labelCounts={}#数据集的类别标签  
  8.     for featVec in dataSet:  
  9.         currentLabel=featVec[-1]#读取数据集的类别  
  10.         if currentLabel not in labelCounts.keys():  
  11.             labelCounts[currentLabel]=0  
  12.         labelCounts[currentLabel]+=1  
  13.     shannonEnt=0.0#香农熵  
  14.     for key in labelCounts:  
  15.         prob=float(labelCounts[key])/numEntries  
  16.         shannonEnt-=prob*log(prob,2)  
  17.     return shannonEnt#返回香农熵  
  18. #创建数据集,返回数据集合和标签列表,该标签是人工标签,仅仅是表明真实类别的名字而已  
  19. def createDataSet():  
  20.     dataSet=[  
  21.         [1,1,'yes'],  
  22.         [1,1,'yes'],  
  23.         [1,0,'no'],  
  24.         [0,1,'no'],  
  25.         [0,1,'no']  
  26.         ]  
  27.     labels=['LabelOne','LableTwo']  
  28.     return dataSet,labels  
  29. #划分数据集,axis是特征值,value是该特征值对应的value  
  30. #将dataSet中特征值等于value的样本点筛选出来  
  31. def splitDataSet(dataSet,axis,value):  
  32.     retDataSet=[]  
  33.     for featVec in dataSet:  
  34.         if featVec[axis]==value:  
  35.             reducedFeatVec=featVec[:axis]  
  36.             reducedFeatVec=featVec[axis+1:]  
  37.             retDataSet.append(reducedFeatVec)  
  38.     return retDataSet  
  39. #遍历数据集找到最好的数据集划分方式  
  40. #H=sum(-p(xi)log(p(xi)))  
  41. def chooseBestFeatureToSplit(dataSet):  
  42.     numFeatures=len(dataSet[0])-1 #numFeature是样本xi的特征数目,请注意最后一个是类别标签  
  43.     baseEntroy=calcShannonEnt(dataSet)#计算原始数据集的香农熵  
  44.     bestInfoGain=0.0;#最大信息增益  
  45.     bestFeature=-1#采用哪个特征分裂  
  46.     #从现有的所有特征之中找到最合适的分裂,是的信息的增益最大  
  47.     for i in range(numFeatures):  
  48.         featList=[example[i] for example in dataSet]#找到第i个feature的所有可能取值 ,有可能有重复的  
  49.         uniqueVals=set(featList)#去重  
  50.         newEntroy=0.0  
  51.         for value in uniqueVals:  
  52.             subDataSet=splitDataSet(dataSet,i,value)  
  53.             #Gain(S,A)=Entroy(S)-Sum((Sv/S)*Entroy(Sv))  信息增益  
  54.             prob=len(subDataSet)/float(len(dataSet))#Sv/S  
  55.             newEntroy+=prob*calcShannonEnt(subDataSet)  
  56.         infoGain=baseEntroy-newEntroy#计算出信息增益  
  57.         if (infoGain>bestInfoGain):#若信息增益大于当前最大的信息增益则需要记录最大的信息增益以及其对应的分裂Feature  
  58.             bestInfoGain=infoGain  
  59.             bestFeature=i  
  60.     return bestFeature#返回按照哪个feature进行Split  
  61. #当特征耗费完的时候,dataSet中仍然的分类标签依然不纯,那么就要用少数服从多数的原则来决定分类了  
  62. def majorityCnt(classList):  
  63.     classCount={}#每种类别包含的个数  
  64.     for vote in classList:  
  65.         if vote not in classCount.keys():  
  66.             classCount[vote]=0  
  67.         classCount[vote]+=1#统计每种类别的个数  
  68.     sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)  
  69.     return sortedClassCount[0][0]  
  70. #创建ID3树  
  71. #首先检测当前数据集是否是纯净的,若是纯净的直接返回类标签  
  72. #再次检测当前数据集的feature是否已经被完全耗费完毕了,若耗费完毕了,直接返回  
  73. #之后构造树形结构,递归构造树  
  74. def createTree(dataSet,labels):  
  75.     classList=[example[-1for example in dataSet]#获取类别标签  
  76.     if classList.count(classList[0])==len(classList):#若当前数据集的所有标签全部都一样了  
  77.         return classList[0]  
  78.     if len(dataSet[0])==1:#特征完全耗费完毕了  
  79.         return majorityCnt(dataSet)  
  80.     #获取最好的分类feature  
  81.     bestFeat=chooseBestFeatureToSplit(dataSet)  
  82.     bestFeatLabel=labels[bestFeat]#获取分类属性的具体名字  
  83.     myTree={bestFeatLabel:{}}#创建树形结构  
  84.     del(labels[bestFeat])#删除分类属性 每次都要耗费掉一个属性  
  85.     #获取最好分类Feature所对应的所有值  
  86.     featValues=[example[bestFeat] for example in dataSet]  
  87.     uniqueVals=set(featValues)  
  88.     for value in uniqueVals:  
  89.         subLabels=labels[:]#子标签  
  90.         myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)  
  91.     return myTree  
  92.   
  93. #测试ID3分类效果 inputTree是决策树本身  
  94. #featLabels是人工表示的类别类型  
  95. #testVec是测试向量  
  96. def classify(inputTree,featLabels,testVec):  
  97.     firstStr=inputTree.keys()[0]#第一个分类的名字  
  98.     secondDict=inputTree[firstStr]#第一个分类名字对应的值{或者是集合或者是标签}  
  99.     featIndex=featLabels.index(firstStr)#当前分类对应的feature的下标  
  100.     for key in secondDict.keys():  
  101.         if testVec[featIndex]==key:#若当前项是集合 递归下去  
  102.             if type(secondDict[key]).__name__=='dict':  
  103.                 classLabel=classify(secondDict[key],featLabels,testVec)  
  104.             else:#若当前项是标签  
  105.                 classLabel=secondDict[key]  
  106.     return classLabel  
  107.   
  108. myDat,labels=createDataSet()  
  109. t=createTree(myDat,labels)  
  110. print t  
  111. labels=['LabelOne','LableTwo']  
  112. print classify(t,labels,[1,0])  
  113. print classify(t,labels,[1,1])  

结果>>> ================================ RESTART ================================
>>> 
{'LabelOne': {0: 'no', 1: {'LableTwo': {0: 'no', 1: 'yes'}}}}
no
>>> ================================ RESTART ================================
>>> 
{'LabelOne': {0: 'no', 1: {'LableTwo': {0: 'no', 1: 'yes'}}}}
no
yes
0 0
原创粉丝点击