决策树

来源:互联网 发布:linux用ubuntu 编辑:程序博客网 时间:2024/06/15 05:32

字典

字典类似于你通过联系人名字查找地址和联系人详细情况的地址簿,即,我们把(名字)和(详细情况)联系在一起。注意,键必须是唯一的,就像如果有两个人恰巧同名的话,你无法找到正确的信息。

注意,你只能使用不可变的对象(比如字符串)来作为字典的键,但是你可以不可变或可变的对象作为字典的值。基本说来就是,你应该只使用简单的对象作为键。

键值对在字典中以这样的方式标记:d = {key1 : value1, key2 : value2 }。注意它们的键/值对用冒号分割,而各个对用逗号分割,所有这些都包括在花括号中。

记住字典中的键/值对是没有顺序的。如果你想要一个特定的顺序,那么你应该在使用前自己对它们排序。

字典是dict类的实例/对象。

#!/usr/bin/python

# Filename: using_dict.py
# 'ab' is short for 'a'ddress'b'ook

ab = {       'Swaroop'   'swaroopch@byteofpython.info',
             'Larry'     'larry@wall.org',
             'Matsumoto' 'matz@ruby-lang.org',
             'Spammer'   'spammer@hotmail.com'
     }
print "Swaroop's address is %s" % ab['Swaroop']
# Adding a key/value pair
ab['Guido'] = 'guido@python.org'
# Deleting a key/value pair
del ab['Spammer']
print '\nThere are %d contacts in the address-book\n' len(ab)
for name, address in ab.items():
    print 'Contact %s at %s' % (name, address)
if 'Guido' in ab: # OR ab.has_key('Guido')
    print "\nGuido's address is %s" % ab['Guido']

输出

$ python using_dict.py
Swaroop's address is swaroopch@byteofpython.info
There are 4 contacts in the address-book
Contact Swaroop at swaroopch@byteofpython.info
Contact Matsumoto at matz@ruby-lang.org
Contact Larry at larry@wall.org
Contact Guido at guido@python.org
Guido's address is guido@python.org

它如何工作

我们使用已经介绍过的标记创建了字典ab。然后我们使用在列表和元组章节中已经讨论过的索引操作符来指定键,从而使用键/值对。我们可以看到字典的语法同样十分简单。

我们可以使用索引操作符来寻址一个键并为它赋值,这样就增加了一个新的键/值对,就像在上面的例子中我们对Guido所做的一样。

我们可以使用我们的老朋友——del语句来删除键/值对。我们只需要指明字典和用索引操作符指明要删除的键,然后把它们传递给del语句就可以了。执行这个操作的时候,我们无需知道那个键所对应的值。

接下来,我们使用字典的items方法,来使用字典中的每个键/值对。这会返回一个元组的列表,其中每个元组都包含一对项目——键与对应的值。我们抓取这个对,然后分别赋给for..in循环中的变量nameaddress然后在for-块中打印这些值。

我们可以使用in操作符来检验一个键/值对是否存在,或者使用dict类的has_key方法。你可以使用help(dict)来查看dict类的完整方法列表。

关键字参数与字典。如果换一个角度看待你在函数中使用的关键字参数的话,你已经使用了字典了!只需想一下——你在函数定义的参数列表中使用的键/值对。当你在函数中使用变量的时候,它只不过是使用一个字典的键(这在编译器设计的术语中被称作 符号表 )

决策树的优势就在于数据形式非常容易理解,而kNN的最大缺点就是无法给出数据的内在含义。

1:简单概念描述

       决策树的类型有很多,有CART、ID3和C4.5等,其中CART是基于基尼不纯度(Gini)的,这里不做详解,而ID3和C4.5都是基于信息熵的,它们两个得到的结果都是一样的,本次定义主要针对ID3算法。下面我们介绍信息熵的定义。

       事件ai发生的概率用p(ai)来表示,而-log2(p(ai))表示为事件ai的不确定程度,称为ai的自信息量,sum(p(ai)*I(ai))称为信源S的平均信息量—信息熵。

    决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一棵熵值下降最快的树,到叶子节点处的熵值为零,此时每个叶节点中的实例都属于同一类。

       ID3的原理是基于信息熵增益达到最大,设原始问题的标签有正例和负例,p和n表示其相应的个数。则原始问题的信息熵为

       其中N为该特征所取值的个数,比如{rain,sunny},则N即为2

              Gain = BaseEntropy – newEntropy

ID3的原理即使Gain达到最大值。信息增益即为熵的减少或者是数据无序度的减少。

ID3易出现的问题:如果是取值更多的属性,更容易使得数据更“纯”(尤其是连续型数值),其信息增益更大,决策树会首先挑选这个属性作为树的顶点。结果训练出来的形状是一棵庞大且深度很浅的树,这样的划分是极为不合理的。 此时可以采用C4.5来解决

C4.5的思想是最大化Gain除以下面这个公式即得到信息增益率:


其中底为2

2:Python代码的实现

(1)   计算信息熵

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. ##计算给定数据集的信息熵  
  2. def calcShannonEnt(dataSet):  
  3.     numEntries = len(dataSet)  
  4.     labelCounts = {}  
  5.     for featVec in dataSet:  
  6.         currentLabel = featVec[-1]  
  7.         if currentLabel not in labelCounts.keys():     #为所有可能分类创建字典  
  8.             labelCounts[currentLabel] = 0  
  9.         labelCounts[currentLabel] += 1  
  10.     shannonEnt = 0.0  
  11.     for key in labelCounts:  
  12.         prob = float(labelCounts[key])/numEntries  
  13.         shannonEnt -= prob * log(prob,2)   #以2为底数求对数  
  14.     return shannonEnt  

(2)   创建数据集

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #创建数据  
  2. def createDataSet():  
  3.     dataSet = [[1,1,'yes'],  
  4.                [1,1,'yes'],  
  5.                [1,0,'no'],  
  6.                [0,1,'no'],  
  7.                [0,1,'no']]  
  8.     labels = ['no surfacing''flippers']  
  9.     return dataSet, labels  

(3)   划分数据集

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #依据特征划分数据集  axis代表第几个特征  value代表该特征所对应的值  返回的是划分后的数据集  
  2. def splitDataSet(dataSet, axis, value):  
  3.     retDataSet = []  
  4.     for featVec in dataSet:  
  5.         if featVec[axis] == value:  
  6.             reducedFeatVec = featVec[:axis]  
  7.             reducedFeatVec.extend(featVec[axis+1:])  
  8.             retDataSet.append(reducedFeatVec)  
  9.     return retDataSet  


(4)   选择最好的特征进行划分

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #选择最好的数据集(特征)划分方式  返回最佳特征下标  
  2. def chooseBestFeatureToSplit(dataSet):  
  3.     numFeatures = len(dataSet[0]) - 1   #特征个数  
  4.     baseEntropy = calcShannonEnt(dataSet)  
  5.     bestInfoGain = 0.0; bestFeature = -1  
  6.     for i in range(numFeatures):   #遍历特征 第i个  
  7.         featureSet = set([example[i] for example in dataSet])   #第i个特征取值集合  
  8.         newEntropy= 0.0  
  9.         for value in featureSet:  
  10.             subDataSet = splitDataSet(dataSet, i, value)  
  11.             prob = len(subDataSet)/float(len(dataSet))  
  12.             newEntropy += prob * calcShannonEnt(subDataSet)   #该特征划分所对应的entropy  
  13.         infoGain = baseEntropy - newEntropy  
  14.         if infoGain > bestInfoGain:  
  15.             bestInfoGain = infoGain  
  16.             bestFeature = i  
  17.     return bestFeature  

注意:这里数据集需要满足以下两个办法:

<1>所有的列元素都必须具有相同的数据长度

<2>数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。

(5)   创建树的代码

Python用字典类型来存储树的结构  返回的结果是myTree-字典

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #创建树的函数代码   python中用字典类型来存储树的结构 返回的结果是myTree-字典  
  2. def createTree(dataSet, labels):  
  3.     classList = [example[-1for example in dataSet]  
  4.     if classList.count(classList[0]) == len(classList):    #类别完全相同则停止继续划分  返回类标签-叶子节点  
  5.         return classList[0]  
  6.     if len(dataSet[0]) == 1:  
  7.         return majorityCnt(classList)       #遍历完所有的特征时返回出现次数最多的  
  8.     bestFeat = chooseBestFeatureToSplit(dataSet)  
  9.     bestFeatLabel = labels[bestFeat]  
  10.     myTree = {bestFeatLabel:{}}  
  11.     del(labels[bestFeat])  
  12.     featValues = [example[bestFeat] for example in dataSet]    #得到的列表包含所有的属性值  
  13.     uniqueVals = set(featValues)  
  14.     for value in uniqueVals:  
  15.         subLabels = labels[:]  
  16.         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  
  17.     return myTree  

其中递归结束当且仅当该类别中标签完全相同或者遍历所有的特征此时返回次数最多的


其中当所有的特征都用完时,采用多数表决的方法来决定该叶子节点的分类,即该叶节点中属于某一类最多的样本数,那么我们就说该叶节点属于那一类!。代码如下:

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #多数表决的方法决定叶子节点的分类 ----  当所有的特征全部用完时仍属于多类  
  2. def majorityCnt(classList):  
  3.     classCount = {}  
  4.     for vote in classList:  
  5.         if vote not in classCount.key():  
  6.             classCount[vote] = 0;  
  7.         classCount[vote] += 1  
  8.     sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)  #排序函数 operator中的  
  9.     return sortedClassCount[0][0]  

即为如果数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们要决定如何定义该叶子节点,在这种情况下,我们通常采用多数表决的方法来决定该叶子节点的分类。

(6)   使用决策树执行分类

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #使用决策树执行分类  
  2. def classify(inputTree, featLabels, testVec):  
  3.     firstStr = inputTree.keys()[0]  
  4.     secondDict = inputTree[firstStr]  
  5.     featIndex = featLabels.index(firstStr)   #index方法查找当前列表中第一个匹配firstStr变量的元素的索引  
  6.     for key in secondDict.keys():  
  7.         if testVec[featIndex] == key:  
  8.             if type(secondDict[key]).__name__ == 'dict':  
  9.                 classLabel = classify(secondDict[key], featLabels, testVec)  
  10.             else: classLabel = secondDict[key]  
  11.     return classLabel  


注意递归的思想很重要。

(7)   决策树的存储

构造决策树是一个很耗时的任务。为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. #决策树的存储  
  2. def storeTree(inputTree, filename):         #pickle序列化对象,可以在磁盘上保存对象  
  3.     import pickle  
  4.     fw = open(filename, 'w')  
  5.     pickle.dump(inputTree, fw)  
  6.     fw.close()  
  7.   
  8.   
  9. def grabTree(filename):               #并在需要的时候将其读取出来  
  10.     import pickle  
  11.     fr = open(filename)  
  12.     return pickle.load(fr)  


3:matplotlib 注解

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。

这段代码我也没看懂,所以只给出书上代码

[python] view plain copy
 print?在CODE上查看代码片派生到我的代码片
  1. # -*- coding: cp936 -*-  
  2. import matplotlib.pyplot as plt  
  3.   
  4. decisionNode = dict(boxstyle = 'sawtooth', fc = '0.8')  
  5. leafNode = dict(boxstyle = 'round4', fc = '0.8')  
  6. arrow_args = dict(arrowstyle = '<-')  
  7.   
  8. def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
  9.     createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',\  
  10.                             xytext = centerPt, textcoords = 'axes fraction',\  
  11.                             va = 'center', ha = 'center', bbox = nodeType, \  
  12.                             arrowprops = arrow_args)  
  13.   
  14. # 使用文本注解绘制树节点  
  15. def createPlot():  
  16.     fig = plt.figure(1, facecolor = 'white')  
  17.     fig.clf()  
  18.     createPlot.ax1 = plt.subplot(111, frameon = False)  
  19.     plotNode('a decision node', (0.5,0.1), (0.1,0.5), decisionNode)  
  20.     plotNode('a leaf node', (0.80.1), (0.3,0.8), leafNode)  
  21.     plt.show()  
  22.   
  23.   
  24. #获取叶子节点数目和树的层数  
  25. def getNumLeafs(myTree):  
  26.     numLeafs = 0  
  27.     firstStr = myTree.keys()[0]  
  28.     secondDict = myTree[firstStr]  
  29.     for key in secondDict.keys():  
  30.         if(type(secondDict[key]).__name__ == 'dict'):  
  31.             numLeafs += getNumLeafs(secondDict[key])  
  32.         else: numLeafs += 1  
  33.     return numLeafs  
  34.   
  35. def getTreeDepth(myTree):  
  36.     maxDepth = 0  
  37.     firstStr = myTree.keys()[0]  
  38.     secondDict = myTree[firstStr]  
  39.     for key in secondDict.keys():  
  40.         if(type(secondDict[key]).__name__ == 'dict'):  
  41.             thisDepth = 1+ getTreeDepth(secondDict[key])  
  42.         else: thisDepth = 1  
  43.         if thisDepth > maxDepth: maxDepth = thisDepth  
  44.     return maxDepth  
  45.   
  46.   
  47. #更新createPlot代码以得到整棵树  
  48. def plotMidText(cntrPt, parentPt, txtString):  
  49.     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
  50.     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
  51.     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)  
  52.   
  53. def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on  
  54.     numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree  
  55.     depth = getTreeDepth(myTree)  
  56.     firstStr = myTree.keys()[0]     #the text label for this node should be this  
  57.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
  58.     plotMidText(cntrPt, parentPt, nodeTxt)  
  59.     plotNode(firstStr, cntrPt, parentPt, decisionNode)  
  60.     secondDict = myTree[firstStr]  
  61.     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
  62.     for key in secondDict.keys():  
  63.         if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes     
  64.             plotTree(secondDict[key],cntrPt,str(key))        #recursion  
  65.         else:   #it's a leaf node print the leaf node  
  66.             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
  67.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  68.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  69.     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
  70. #if you do get a dictonary you know it's a tree, and the first element will be another dict  
  71.   
  72. def createPlot(inTree):  
  73.     fig = plt.figure(1, facecolor='white')  
  74.     fig.clf()  
  75.     axprops = dict(xticks=[], yticks=[])  
  76.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks  
  77.     #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
  78.     plotTree.totalW = float(getNumLeafs(inTree))  
  79.     plotTree.totalD = float(getTreeDepth(inTree))  
  80.     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
  81.     plotTree(inTree, (0.5,1.0), '')  
  82.     plt.show()  

其中index方法为查找当前列表中第一个匹配firstStr的元素 返回的为索引。

4:使用决策树预测隐形眼镜类型


0 0