决策树算法

来源:互联网 发布:linux如何卸载jdk1.7 编辑:程序博客网 时间:2024/05/21 17:34

ID3决策树算法类似算法流程图。



决策树算法

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。

缺点:可能会产生过度匹配问题。

适用数据类型:数值型和标称型

基于Python的实现代码:

1)准备子函数

[python] view plain copy print?
  1. # -*- coding: cp936 -*-  
  2.   
  3. from math import log  
  4. import operator  
  5.   
  6. def createDataSet():#创建数据集  
  7.     dataSet = [[11'yes'],  
  8.                [11'yes'],  
  9.                [10'no'],  
  10.                [01'no'],  
  11.                [01'no']]  
  12.     labels = ['no surfacing','flippers']  
  13.     #change to discrete values  
  14.     return dataSet, labels  
  15.   
  16. def calcShannonEnt(dataSet):  
  17.     numEntries = len(dataSet)     #计算数据集的长度  
  18.     labelCounts = {}              #定义一个label字典,统计每个label出现的次数,键值为label,值为对应label出现的次数  
  19.     for featVec in dataSet:       #the the number of unique elements and their occurance  
  20.         currentLabel = featVec[-1]#数据集每个元素都是一个列表,每个元素列表的最后一列为label  
  21.         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 #判断当前label是否已经存在字典键值列表中,没有存在的话,将当前label加入字典,并设置对应值为0  
  22.         labelCounts[currentLabel] += 1                                           #否则,当前label出现次数累加  
  23.     shannonEnt = 0.0  
  24.     for key in labelCounts:  
  25.         prob = float(labelCounts[key])/numEntries #计算每个label出现的概率  
  26.         shannonEnt -= prob * log(prob,2)          #计算数据集的(香农信息熵)信息熵,其中log base 2  
  27.     return shannonEnt  
  28.       
  29. def splitDataSet(dataSet, axis, value):#按照给定特征划分数据集:待划分数据集,划分数据集的特征,需要返回的特征值  
  30.     retDataSet = []                    #定义一个空列表,即:子数据集  
  31.     for featVec in dataSet:  
  32.         if featVec[axis] == value:                  #判断待划分数据集中元素列表指定位置的特征是否与需要返回的特征值匹配  
  33.             reducedFeatVec = featVec[:axis]         #chop out axis used for splitting  
  34.             reducedFeatVec.extend(featVec[axis+1:]) #获取待划分数据集中元素列表的子元素列表(已经裁剪掉指定数据集的特征)  
  35.             retDataSet.append(reducedFeatVec)       #添加获取的子元素列表到子数据集中  
  36.     return retDataSet  
  37.       
  38. def chooseBestFeatureToSplit(dataSet):     #选择最好的数据集划分方式--以不同特征划分子数据集的信息熵增益(或者数据集信息熵减少)大小为依据!  
  39.     numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels  
  40.     baseEntropy = calcShannonEnt(dataSet)  #计算整个数据集的信息熵  
  41.     bestInfoGain = 0.0; bestFeature = -1  
  42.     for i in range(numFeatures):                      #iterate over all the features  
  43.         featList = [example[i] for example in dataSet]#create a list of all the examples of this feature 运用到列表推导式  
  44.         uniqueVals = set(featList)                    #get a set of unique values  
  45.         newEntropy = 0.0  
  46.         for value in uniqueVals:  
  47.             subDataSet = splitDataSet(dataSet, i, value)  
  48.             prob = len(subDataSet)/float(len(dataSet))  
  49.             newEntropy += prob * calcShannonEnt(subDataSet)       
  50.         infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy  
  51.         if (infoGain > bestInfoGain):           #compare this to the best gain so far  
  52.             bestInfoGain = infoGain             #if better than current best, set to best  
  53.             bestFeature = i  
  54.     return bestFeature                          #returns an integer  
2) 构建决策树
[python] view plain copy print?
  1. def majorityCnt(classList):                     #运用多数表决方法判定label不唯一时,叶子节点的分类  
  2.     classCount={}  
  3.     for vote in classList:  
  4.         if vote not in classCount.keys(): classCount[vote] = 0  
  5.         classCount[vote] += 1  
  6.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  7.     return sortedClassCount[0][0]               #返回label出现次数最多的所属分类  
  8.   
  9. def createTree(dataSet,labels):                 #创建决策树  
  10.     classList = [example[-1for example in dataSet]  
  11.     if classList.count(classList[0]) == len(classList): #list.count(list[0])返回指定位置0对应值list[0],出现的次数  
  12.         return classList[0]                     #stop splitting when all of the classes are equal  
  13.     if len(dataSet[0]) == 1:                    #stop splitting when there are no more features in dataSet  
  14.         return majorityCnt(classList)  
  15.     bestFeat = chooseBestFeatureToSplit(dataSet)  
  16.     bestFeatLabel = labels[bestFeat]  
  17.     myTree = {bestFeatLabel:{}}  
  18.     del(labels[bestFeat])                      #删除已经使用的最佳划分数据集特征  
  19.     featValues = [example[bestFeat] for example in dataSet]  
  20.     uniqueVals = set(featValues)  
  21.     for value in uniqueVals:  
  22.         subLabels = labels[:]                  #copy all of labels, so trees don't mess up existing labels  
  23.         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)  
  24.     return myTree                              
  25.       
  26. def classify(inputTree,featLabels,testVec):    #使用决策树分类函数进行分类  
  27.     firstStr = inputTree.keys()[0]  
  28.     secondDict = inputTree[firstStr]  
  29.     featIndex = featLabels.index(firstStr)  
  30.     key = testVec[featIndex]  
  31.     valueOfFeat = secondDict[key]  
  32.     if isinstance(valueOfFeat, dict):          #判断是否为字典类型的节点,如果是,则该节点为判断节点,否则,该节点为叶子节点  
  33.         classLabel = classify(valueOfFeat, featLabels, testVec)  
  34.     else: classLabel = valueOfFeat  
  35.     return classLabel  
  36.   
  37. def storeTree(inputTree,filename): #利用pickle模块存储已经创建好的决策树,以便后续使用中无需重新构建             
  38.     import pickle  
  39.     fw = open(filename,'w')  
  40.     pickle.dump(inputTree,fw)  
  41.     fw.close()  
  42.       
  43. def grabTree(filename):  
  44.     import pickle  
  45.     fr = open(filename)  
  46.     return pickle.load(fr)  

3) 程序运行截图:(这里用的pythonxy里面的IPython(sh)交换环境)



实例测试:

lenses.txt内容如下所示:

[plain] view plain copy print?
  1. young   myope   no  reduced no lenses  
  2. young   myope   no  normal  soft  
  3. young   myope   yes reduced no lenses  
  4. young   myope   yes normal  hard  
  5. young   hyper   no  reduced no lenses  
  6. young   hyper   no  normal  soft  
  7. young   hyper   yes reduced no lenses  
  8. young   hyper   yes normal  hard  
  9. pre myope   no  reduced no lenses  
  10. pre myope   no  normal  soft  
  11. pre myope   yes reduced no lenses  
  12. pre myope   yes normal  hard  
  13. pre hyper   no  reduced no lenses  
  14. pre hyper   no  normal  soft  
  15. pre hyper   yes reduced no lenses  
  16. pre hyper   yes normal  no lenses  
  17. presbyopic  myope   no  reduced no lenses  
  18. presbyopic  myope   no  normal  no lenses  
  19. presbyopic  myope   yes reduced no lenses  
  20. presbyopic  myope   yes normal  hard  
  21. presbyopic  hyper   no  reduced no lenses  
  22. presbyopic  hyper   no  normal  soft  
  23. presbyopic  hyper   yes reduced no lenses  
  24. presbyopic  hyper   yes normal  no lenses  


基于matplotlib模块的python绘图代码如下所示:

[python] view plain copy print?
  1. import matplotlib.pyplot as plt  
  2.   
  3. decisionNode = dict(boxstyle="sawtooth", fc="0.8")  
  4. leafNode = dict(boxstyle="round4", fc="0.8")  
  5. arrow_args = dict(arrowstyle="<-")  
  6.   
  7. def getNumLeafs(myTree):  
  8.     numLeafs = 0  
  9.     firstStr = myTree.keys()[0]  
  10.     secondDict = myTree[firstStr]  
  11.     for key in secondDict.keys():  
  12.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  13.             numLeafs += getNumLeafs(secondDict[key])  
  14.         else:   numLeafs +=1  
  15.     return numLeafs  
  16.   
  17. def getTreeDepth(myTree):  
  18.     maxDepth = 0  
  19.     firstStr = myTree.keys()[0]  
  20.     secondDict = myTree[firstStr]  
  21.     for key in secondDict.keys():  
  22.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes  
  23.             thisDepth = 1 + getTreeDepth(secondDict[key])  
  24.         else:   thisDepth = 1  
  25.         if thisDepth > maxDepth: maxDepth = thisDepth  
  26.     return maxDepth  
  27.   
  28. def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
  29.     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  
  30.              xytext=centerPt, textcoords='axes fraction',  
  31.              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )  
  32.       
  33. def plotMidText(cntrPt, parentPt, txtString):  
  34.     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
  35.     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
  36.     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)  
  37.   
  38. def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on  
  39.     numLeafs = getNumLeafs(myTree)      #this determines the x width of this tree  
  40.     depth = getTreeDepth(myTree)  
  41.     firstStr = myTree.keys()[0]         #the text label for this node should be this  
  42.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
  43.     plotMidText(cntrPt, parentPt, nodeTxt)  
  44.     plotNode(firstStr, cntrPt, parentPt, decisionNode)  
  45.     secondDict = myTree[firstStr]  
  46.     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
  47.     for key in secondDict.keys():  
  48.         if type(secondDict[key]).__name__=='dict':    #test to see if the nodes are dictonaires, if not they are leaf nodes     
  49.             plotTree(secondDict[key],cntrPt,str(key)) #recursion  
  50.         else:                                         #it's a leaf node print the leaf node  
  51.             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
  52.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  53.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  54.     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
  55. #if you do get a dictonary you know it's a tree, and the first element will be another dict  
  56.   
  57. def createPlot(inTree):  
  58.     fig = plt.figure(1, facecolor='white')  
  59.     fig.clf()  
  60.     axprops = dict(xticks=[], yticks=[])  
  61.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks  
  62.     #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  63.     plotTree.totalW = float(getNumLeafs(inTree))  
  64.     plotTree.totalD = float(getTreeDepth(inTree))  
  65.     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
  66.     plotTree(inTree, (0.5,1.0), '')  
  67.     plt.show()  
  68.   
  69. #def createPlot():  
  70. #    fig = plt.figure(1, facecolor='white')  
  71. #    fig.clf()  
  72. #    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  73. #    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)  
  74. #    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)  
  75. #    plt.show()  
  76.   
  77. def retrieveTree(i):  
  78.     listOfTrees =[{'no surfacing': {0'no'1: {'flippers': {0'no'1'yes'}}}},  
  79.                   {'no surfacing': {0'no'1: {'flippers': {0: {'head': {0'no'1'yes'}}, 1'no'}}}}  
  80.                   ]  
  81.     return listOfTrees[i]  
  82.   
  83. #createPlot(thisTree)  



原创粉丝点击