A Simple Python Implementation of Decision Tree Pruning

Source: Internet. Editor: 程序博客网. Time: 2024/06/05 06:44

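A decision tree is a tree built from a sequence of decisions. In machine learning it is a predictive model that represents a mapping from an object's attributes to its value: each internal node tests one attribute, each branch leaving a node corresponds to one possible value of that attribute, and each leaf holds the value predicted for objects whose attributes follow the path from the root to that leaf. A decision tree has a single output; if several outputs are needed, a separate tree can be built for each output.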
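The code later in this post stores a tree as nested Python dicts of the form `{feature: {feature_value: subtree_or_class_label}}`. A minimal sketch of that representation and of following one root-to-leaf path; the feature names, values and the helper `predict` are hypothetical and only illustrate the idea:

```python
# Hypothetical toy tree in the nested-dict shape used later in this post:
# {feature: {feature_value: subtree_or_class_label}}
tree = {
    "texture": {
        "clear": {"root": {"curled": "good", "stiff": "bad"}},
        "blurry": "bad",
    }
}
feature_names = ["texture", "root"]


def predict(node, feature_names, sample):
    """Follow one root-to-leaf path; any non-dict node is a leaf label."""
    while isinstance(node, dict):
        feature = next(iter(node))                   # the attribute tested here
        value = sample[feature_names.index(feature)]
        node = node[feature][value]                  # take the branch for this value
    return node


print(predict(tree, feature_names, ["clear", "curled"]))  # good
print(predict(tree, feature_names, ["blurry", "stiff"]))  # bad
```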

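ID3 algorithm: ID3 (Iterative Dichotomiser 3), invented by Ross Quinlan, is a decision tree algorithm grounded in Occam's razor: do as much as possible with as little as possible, so a smaller decision tree is preferred to a larger one. ID3 does not always find the smallest tree, however; it is a heuristic. Its core idea is to score attributes by information gain and split on the attribute with the largest gain. The smaller the expected information (the entropy that remains after a split), the larger the information gain and the purer the resulting subsets. The algorithm performs a top-down greedy search through the space of possible trees.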
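Information entropy measures the uncertainty of a discrete random variable. For a data set $D$ in which class $k$ has proportion $p_k$, $H(D) = -\sum_k p_k \log_2 p_k$. The more ordered a system is, the lower its entropy; the more disordered, the higher. Entropy can therefore be regarded as a measure of how ordered a system is. The information gain of splitting $D$ on attribute $a$ is $\mathrm{Gain}(D, a) = H(D) - \sum_v \frac{|D^v|}{|D|} H(D^v)$, where $D^v$ is the subset of samples that take value $v$ on $a$.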
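As a worked sketch of the two formulas, assuming one discrete feature and the class labels are given as plain Python lists (`entropy` and `info_gain` are illustrative names, not functions from the original code):

```python
from collections import Counter
from math import log2


def entropy(labels):
    """H(D) = -sum_k p_k * log2(p_k) over the class proportions p_k."""
    n = len(labels)
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def info_gain(values, labels):
    """Gain(D, a) = H(D) - sum_v |D^v|/|D| * H(D^v) for a discrete feature a."""
    n = len(labels)
    conditional = 0.0
    for v in set(values):
        subset = [y for x, y in zip(values, labels) if x == v]
        conditional += len(subset) / n * entropy(subset)
    return entropy(labels) - conditional


# Hypothetical toy data: one discrete feature and a binary class label.
feature = ["a", "a", "b", "b"]
label = ["yes", "yes", "no", "yes"]
print(entropy(label))             # approximately 0.8113
print(info_gain(feature, label))  # approximately 0.3113
```

Here branch `a` is pure (entropy 0) and branch `b` is a 50/50 mix (entropy 1), so the remaining entropy is 0.5 and the gain is about 0.811 - 0.5.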

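Gini index: CART chooses splits using the Gini index, defined as $\mathrm{gini}(T) = 1 - \sum_{j=1}^{n} p_j^2$, where $p_j$ is the relative frequency of class $j$ in $T$. The more the class distribution in $T$ is concentrated on one class, the smaller $\mathrm{gini}(T)$ is (it is 0 when $T$ contains a single class). If $T$ is split into two subsets $T_1$ (with $N_1$ instances) and $T_2$ (with $N_2$ instances), $N = N_1 + N_2$, the Gini index of the split is $\mathrm{gini}_{split}(T) = \frac{N_1}{N}\mathrm{gini}(T_1) + \frac{N_2}{N}\mathrm{gini}(T_2)$, and the split with the smallest $\mathrm{gini}_{split}(T)$ is chosen for the node.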
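A small sketch of the two Gini formulas above, with labels given as Python lists. `gini` and `gini_split` are illustrative names; the implementation in this post itself uses entropy, not Gini:

```python
from collections import Counter


def gini(labels):
    """gini(T) = 1 - sum_j p_j^2, p_j = relative frequency of class j in T."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gini_split(left, right):
    """Weighted Gini of a binary split: N1/N * gini(T1) + N2/N * gini(T2)."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Hypothetical labels: the first split separates the classes, the second does not.
print(gini(["yes", "yes", "no", "no"]))          # 0.5
print(gini_split(["yes", "yes"], ["no", "no"]))  # 0.0
print(gini_split(["yes", "no"], ["yes", "no"]))  # 0.5
```

CART would prefer the first split, since it has the smaller weighted Gini.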
Implementation
First, calcShannonEnt computes the Shannon entropy of the data set, building a dictionary that counts every possible class:

```python
from math import log


def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # Count the samples of every class (the class label is the last column)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Shannon entropy with a base-2 logarithm
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
```

```python
# Split on a discrete feature: keep every sample whose feature `axis` equals
# `value`, and drop that feature column from the returned samples
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```

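For a continuous feature the data set is split with splitContinuousDataSet(dataSet, axis, value, direction); `direction` decides whether the samples smaller than `value` or the samples larger than `value` are returned.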
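chooseBestFeatureToSplit calls splitContinuousDataSet, but its definition is not included in the listing. A minimal sketch, assuming direction 0 keeps the samples greater than `value` and direction 1 keeps those less than or equal to it (consistent with the later binarisation, where `<=` becomes 1), and assuming it drops the feature column the way splitDataSet does:

```python
def splitContinuousDataSet(dataSet, axis, value, direction):
    """Split on continuous feature `axis` at threshold `value`.

    Assumed convention (inferred, not from the original listing):
      direction == 0 -> keep samples with featVec[axis] >  value
      direction == 1 -> keep samples with featVec[axis] <= value
    The feature column is removed from each returned sample.
    """
    retDataSet = []
    for featVec in dataSet:
        keep = featVec[axis] > value if direction == 0 else featVec[axis] <= value
        if keep:
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


data = [[0.2, "yes"], [0.5, "no"], [0.7, "no"]]
print(splitContinuousDataSet(data, 0, 0.35, 1))  # [['yes']]
print(splitContinuousDataSet(data, 0, 0.35, 0))  # [['no'], ['no']]
```

Because chooseBestFeatureToSplit weights and sums the entropies of both halves, swapping the meaning of the two directions would not change which threshold wins.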

```python
def chooseBestFeatureToSplit(dataSet, labels):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Continuous feature
        if isinstance(featList[0], (int, float)):
            # Generate n-1 candidate split points (midpoints of adjacent sorted values)
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)
            bestSplitEntropy = float('inf')
            bestSplit = 0
            # Entropy after splitting at the j-th candidate; remember the best one
            for j in range(len(splitList)):
                value = splitList[j]
                newEntropy = 0.0
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                newEntropy += prob0 * calcShannonEnt(subDataSet0)
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy += prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            # Record the best split point of this feature
            bestSplitDict[labels[i]] = splitList[bestSplit]
            infoGain = baseEntropy - bestSplitEntropy
        # Discrete feature
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            # Weighted entropy of every branch of this feature
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # If the best feature is continuous, binarise it at the recorded split point:
    # 1 if the value is <= bestSplitValue, otherwise 0 (modifies dataSet and labels)
    if bestFeature != -1 and isinstance(dataSet[0][bestFeature], (int, float)):
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for i in range(len(dataSet)):
            if dataSet[i][bestFeature] <= bestSplitValue:
                dataSet[i][bestFeature] = 1
            else:
                dataSet[i][bestFeature] = 0
    return bestFeature
```

classify walks the tree for one sample. A continuous node is stored as `name<=threshold`, with branch 1 for values `<=` the threshold and branch 0 otherwise:

```python
def classify(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree))   # feature tested at this node
    secondDict = inputTree[firstStr]
    classLabel = None                  # stays None if no branch matches
    if '<=' in firstStr:
        featkey, featvalue = firstStr.split('<=')
        featIndex = featLabels.index(featkey)
        judge = 1 if testVec[featIndex] <= float(featvalue) else 0
        for key in secondDict:
            if judge == int(key):
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    else:
        featIndex = featLabels.index(firstStr)
        for key in secondDict:
            if testVec[featIndex] == key:
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    return classLabel
```
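The continuous-feature branch of chooseBestFeatureToSplit is a bi-partition search: sort the n values, take the n-1 midpoints of adjacent values as candidate thresholds t, and keep the t whose split {x <= t} / {x > t} has the smallest weighted entropy. A standalone sketch of just that search; `best_threshold` and the data are hypothetical:

```python
from collections import Counter
from math import log2


def entropy(labels):
    n = len(labels)
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def best_threshold(values, labels):
    """Return (t, weighted entropy) for the midpoint threshold t whose split
    {x <= t} / {x > t} has the lowest weighted entropy."""
    xs = sorted(values)
    candidates = [(xs[j] + xs[j + 1]) / 2.0 for j in range(len(xs) - 1)]
    n = len(labels)
    best = None
    for t in candidates:
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        h = len(left) / n * entropy(left) + len(right) / n * entropy(right)
        if best is None or h < best[1]:
            best = (t, h)
    return best


print(best_threshold([1.0, 2.0, 4.0, 8.0], ["yes", "yes", "no", "no"]))  # (3.0, 0.0)
```

The candidates here are 1.5, 3.0 and 6.0; only 3.0 separates the two classes completely.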
```python
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # Return the most frequent class label
    return max(classCount, key=classCount.get)


# Validation error if the node is split on `feat` and every branch predicts
# the majority class of its training samples (used for pre-pruning)
def testing_feat(feat, train_data, test_data, labels):
    class_list = [example[-1] for example in train_data]
    bestFeatIndex = labels.index(feat)
    train_data = [example[bestFeatIndex] for example in train_data]
    test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
    all_feat = set(train_data)
    error = 0.0
    for value in all_feat:
        class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
        major = majorityCnt(class_feat)
        for data in test_data:
            if data[0] == value and data[1] != major:
                error += 1.0
    return error
```
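In the "prev" mode of createTree, testing_feat (validation errors if the node is split, each branch predicting its training majority) is compared with testingMajor (validation errors if the node becomes a single leaf), and the node is split only when splitting is strictly better. A self-contained sketch of that decision for one discrete feature; names and data are hypothetical, and unlike testing_feat, which never counts validation samples whose value did not occur in training, this sketch assumes such samples fall back to the node's majority label:

```python
from collections import Counter


def majority(labels):
    """Most frequent label; ties go to the label seen first."""
    return Counter(labels).most_common(1)[0][0]


def pre_prune_decision(train_x, train_y, val_x, val_y):
    """Return (split?, validation errors if split, validation errors as a leaf)."""
    leaf = majority(train_y)
    leaf_errors = sum(y != leaf for y in val_y)
    # Each branch predicts the majority label of its own training samples
    branch = {v: majority([y for x, y in zip(train_x, train_y) if x == v])
              for v in set(train_x)}
    # Assumption: a value never seen in training falls back to `leaf`
    split_errors = sum(branch.get(x, leaf) != y for x, y in zip(val_x, val_y))
    return split_errors < leaf_errors, split_errors, leaf_errors


# Hypothetical toy node: one discrete feature, labels "yes"/"no".
train_x = ["a", "a", "b", "b", "b"]
train_y = ["yes", "yes", "no", "no", "yes"]
print(pre_prune_decision(train_x, train_y, ["a", "b", "b"], ["yes", "no", "no"]))
# (True, 0, 2)
```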

Testing

```python
# Validation error of the whole tree myTree on data_test
def testing(myTree, data_test, labels):
    error = 0.0
    for i in range(len(data_test)):
        if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
            error += 1
    return float(error)


# Validation error if the node is replaced by a leaf predicting `major`
def testingMajor(major, data_test):
    error = 0.0
    for i in range(len(data_test)):
        if major != data_test[i][-1]:
            error += 1
    return float(error)
```

**Building the tree recursively**

```python
import copy


def createTree(dataSet, labels, data_full, labels_full, test_data, mode):
    # mode: "unpro" = no pruning, "prev" = pre-pruning, "post" = post-pruning
    classList = [example[-1] for example in dataSet]
    # All samples belong to one class: return a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    labels_copy = copy.deepcopy(labels)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    # No feature gives a positive information gain: return the majority class
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # A continuous feature was binarised in dataSet; binarise the validation
    # rows with the same threshold so they can follow the 1/0 branches
    test_split = test_data
    if '<=' in bestFeatLabel:
        threshold = float(bestFeatLabel.split('<=')[1])
        test_split = [row[:bestFeat] + [1 if row[bestFeat] <= threshold else 0] + row[bestFeat + 1:]
                      for row in test_data]
    if mode == "unpro" or mode == "post":
        myTree = {bestFeatLabel: {}}
    elif mode == "prev":
        # Pre-pruning: split only if splitting lowers the validation error
        if testing_feat(bestFeatLabel, dataSet, test_split, labels) < testingMajor(majorityCnt(classList), test_data):
            myTree = {bestFeatLabel: {}}
        else:
            return majorityCnt(classList)
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    isDiscrete = isinstance(dataSet[0][bestFeat], str)
    if isDiscrete:
        # Every value this feature takes in the full training set
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    del labels[bestFeat]
    for value in uniqueVals:
        subLabels = labels[:]
        if isDiscrete:
            uniqueValsFull.discard(value)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, data_full, labels_full,
            splitDataSet(test_split, bestFeat, value), mode=mode)
    # Values that occur in the full training set but not at this node get the majority class
    if isDiscrete:
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    # Post-pruning: replace the subtree by a leaf if the subtree has more validation errors
    if mode == "post":
        if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
            return majorityCnt(classList)
    return myTree
```

**Loading the data**

```python
import pandas as pd


def load_data(file_name):
    df = pd.read_csv(file_name, sep=",")
    # The first column (assumed to be a sample ID) is dropped; the first 11 rows
    # form the training set and the remaining rows the validation set
    train_data = df.values[:11, 1:].tolist()
    test_data = df.values[11:, 1:].tolist()
    # Feature names: every column except the ID column and the class column
    labels = df.columns.values[1:-1].tolist()
    return train_data, test_data, labels
```

Testing and drawing the tree

```python
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="round4", color='red')  # style of decision nodes
leafNode = dict(boxstyle="circle", color='grey')     # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='blue')     # arrow style


# Number of leaves in the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


# Draw a node with an arrow from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction', va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)


# Write the branch value halfway along an arrow
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
```

Test run

```python
import copy
import json

if __name__ == "__main__":
    train_data, test_data, labels = load_data("dd.csv")
    data_full = copy.deepcopy(train_data)
    labels_full = labels[:]
    # "unpro": unpruned tree, "prev": pre-pruning, "post": post-pruning
    mode = "post"
    myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)
    createPlot(myTree)
    print(json.dumps(myTree, ensure_ascii=False, indent=4))
```

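That is the complete code. Setting `mode` to "unpro" (no pruning), "prev" (pre-pruning) or "post" (post-pruning) produces the three different trees.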
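In "post" mode, createTree builds a node's children first (so they are already pruned, because the recursion returns bottom-up) and then replaces the node with its majority training label if the subtree makes more validation errors than that single leaf. A standalone sketch of the same reduced-error rule on a hand-written tree; `post_prune`, `predict` and the data are hypothetical, rows keep all their feature columns, and a validation value with no matching branch counts as an error:

```python
from collections import Counter


def majority(labels):
    """Most frequent label; ties go to the label seen first."""
    return Counter(labels).most_common(1)[0][0]


def predict(node, names, row):
    """Follow the tree for one row; an unseen feature value yields None."""
    while isinstance(node, dict):
        feat = next(iter(node))
        node = node[feat].get(row[names.index(feat)])
    return node


def post_prune(node, names, train, val):
    """Bottom-up reduced-error pruning; rows are [features..., label].

    Children are pruned first; the node then becomes a leaf with the majority
    training label if the subtree makes MORE validation errors than that leaf
    (ties keep the subtree, as in createTree's "post" mode). Modifies `node`.
    """
    if not isinstance(node, dict) or not train:
        return node
    feat = next(iter(node))
    i = names.index(feat)
    for value, child in list(node[feat].items()):
        node[feat][value] = post_prune(child, names,
                                       [r for r in train if r[i] == value],
                                       [r for r in val if r[i] == value])
    leaf = majority([r[-1] for r in train])
    tree_errors = sum(predict(node, names, r) != r[-1] for r in val)
    leaf_errors = sum(leaf != r[-1] for r in val)
    return leaf if tree_errors > leaf_errors else node


# Hypothetical toy data: features (color, size); the last column is the class.
names = ["color", "size"]
tree = {"color": {"red": {"size": {"big": "yes", "small": "no"}}, "green": "no"}}
train = [["red", "big", "yes"], ["red", "small", "no"],
         ["red", "big", "yes"], ["green", "big", "no"]]
val = [["red", "big", "yes"], ["red", "small", "yes"], ["green", "small", "no"]]
print(post_prune(tree, names, train, val))
# {'color': {'red': 'yes', 'green': 'no'}}
```

The inner `size` node is pruned (one validation error as a subtree, none as a leaf), while the root is kept (no errors as a subtree, one as a leaf).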
