针对分类问题的决策树模型

来源:互联网 发布:婵真银杏 知乎 编辑:程序博客网 时间:2024/06/07 21:30

针对分类问题的决策树模型




以下代码片内容为周志华著《机器学习》图4.7的生成程序。


# 带后剪枝的决策树# 所有属性值均来自初始训练集,而不来自当前训练集# 样本集DataSet为二维列表,列表的每一行为一个样本,列表的最后一列为label# 属性名集AttrSet为一维列表# 属性值集AttrValSet为二维列表,列表的每一行为一个属性的值的集合def cal_shan_ent(DataSet):    import math    LabelCount = {}    for Sample in DataSet:        if Sample[-1] in LabelCount.keys():            LabelCount[Sample[-1]] += 1        else:            LabelCount[Sample[-1]] = 1    EntCal = 0    for Key in LabelCount:        Prob = float(LabelCount[Key])/len(DataSet)        EntCal = EntCal - Prob*math.log2(Prob)    return EntCaldef spl_data(DataSet,Axis,Value):    DataSetSpl = []    for Sam in DataSet:        if Sam[Axis]==Value:            SamRedu = Sam[0:Axis]            SamRedu.extend(Sam[Axis+1:])            DataSetSpl.append(SamRedu)    return DataSetSpldef find_all_val(DataSet,BestAttrIndex):    NumSample = len(DataSet)    AttrVal = []    for n in range(NumSample):        if DataSet[n][BestAttrIndex] in AttrVal:            pass        else:            AttrVal.extend([DataSet[n][BestAttrIndex]])    return AttrValdef maxgain_cal(DataSet):    GainCount = []    NumSample = len(DataSet)    NumAttr = len(DataSet[0]) - 1    for m in range(NumAttr):        GainCal = cal_shan_ent(DataSet)        AttrVal = find_all_val(DataSet, m)        for Val in AttrVal:            Data_Val = spl_data(DataSet,m,Val)            GainCal -= len(Data_Val)/NumSample*cal_shan_ent(Data_Val)        GainCount.append(GainCal)    return GainCount.index(max(GainCount))def major_label(DataSet):    LabelCount = {}    for Sample in DataSet:        if Sample[-1] in LabelCount.keys():            LabelCount[Sample[-1]] += 1        else:            LabelCount[Sample[-1]] = 1    return max(LabelCount) #当标记的各类别数目相同时,返回字典中最后一个类别def judge_attr_same(DataSet):    Flag = 'same'    NumCompare = len(DataSet[0]) - 1    for n in range(len(DataSet)):        if DataSet[0][0:NumCompare] != DataSet[n][0:NumCompare]:            Flag = 'not same'    if Flag=='same':        return True    else:        return Falsedef judge_label_same(DataSet):    Flag = 'same'    for n in range(len(DataSet)):        if DataSet[0][-1] != DataSet[n][-1]:            Flag = 'not same'    if Flag=='same':        return True    else:        return Falsedef create_tree(DataSet,AttrSet,AttrValSet):    # 如果所有样本的标记类别都相同,说明截止到当前结点,再作进一步划分,不会获得信息增益,返回该标记类别作为叶子    if judge_label_same(DataSet):        return DataSet[0][-1]    # 如果属性集为空,说明截止到当前结点(调用create_tree的Val_BestAttr),所有属性均分完,返回样本集中最多的标记类别作为叶子    if AttrSet==[]:        return major_label(DataSet)    # 如果样本集中所有样本的属性均相同,说明截止到当前结点,再作进一步划分,只会让决策树过拟合且面对某些样本时有缺失,返回样本集中最多的标记类别作为叶子    if judge_attr_same(DataSet):        return major_label(DataSet)    # 如果不满足结束递归的条件    # 寻找最优划分属性    BestAttrIndex = maxgain_cal(DataSet)    BestAttr = AttrSet[BestAttrIndex]    Tree = {BestAttr: {}}    # 在初始属性值的集合中寻找该属性的所有属性值    BestAttrVal = AttrValSet[BestAttrIndex]    # 遍历所有属性值:划分样本集,缩减属性集,递归    for Val in BestAttrVal:        DataSetDiv = spl_data(DataSet, BestAttrIndex, Val)        if DataSetDiv==[]:            Tree[BestAttr][Val] = major_label(DataSet)        else:            AttrSetShort = AttrSet[0:BestAttrIndex]            AttrSetShort.extend(AttrSet[BestAttrIndex+1:])            AttrValSetShort = AttrValSet[0:BestAttrIndex]            AttrValSetShort.extend(AttrValSet[BestAttrIndex+1:])            Tree[BestAttr][Val] = create_tree(DataSetDiv,AttrSetShort,AttrValSetShort)    return Treedef judge_label_sam(TreeCal,TestSam,AttrSet):    AttrJudged = list(TreeCal.keys())[0]    IndexJudged = AttrSet.index(AttrJudged)    if type(TreeCal[AttrJudged][TestSam[IndexJudged]])!=type({}):        return TreeCal[AttrJudged][TestSam[IndexJudged]]    TreeNext = TreeCal[AttrJudged][TestSam[IndexJudged]]    return judge_label_sam(TreeNext,TestSam,AttrSet)def judge_label_set(TreeCal,TestSet,AttrSet):    LabelJudged = []    for TestSam in TestSet:        LabelJudged.append(judge_label_sam(TreeCal, TestSam, AttrSet))    return LabelJudgeddef get_nodelines(TreeCal,KeysLines,KeysLineCurr):    import copy    # 若当前传入的树只有一层,将该结点放入Keys链条中,再将链条放入链条集合,并结束递归    AttrRoot = list(TreeCal.keys())[0]    AttrValRoot = list(TreeCal[AttrRoot].keys())    Flag = 'OneFlat'    for AttrVal in AttrValRoot:        if type(TreeCal[AttrRoot][AttrVal])==type({}):            Flag = 'NotOneFlat'    if Flag == 'OneFlat':        KeysLineCurr.extend([AttrRoot])        KeysLines.append(KeysLineCurr)        return KeysLines    # 若当前传入的树不止一层,先将该结点放入Keys链条中,将链条放入链条集合,再递归    KeysLineCurr.extend([AttrRoot])    KeysLines.append(KeysLineCurr)    for AttrVal in AttrValRoot:        KeysLineTemp = copy.deepcopy(KeysLineCurr)        if type(TreeCal[AttrRoot][AttrVal])==type({}):            KeysLineTemp.extend([AttrVal])            KeysLines = get_nodelines(TreeCal[AttrRoot][AttrVal], KeysLines, KeysLineTemp)    # 排序,链条长的在前面    KeysLines.sort(reverse=True)    return KeysLinesdef cal_label(KeysLine,DataSet,AttrSet):    # 找到链条对应的数据集    JudgeNum = int(len(KeysLine)/2)    if JudgeNum==-1:        DataSetNew = DataSet    else:        DataSetNew = []        for Sam in DataSet:            Flag = 'AllSame'            for m in range(JudgeNum):                Attr = KeysLine[m*2]                AttrVal = KeysLine[m*2+1]                if Sam[AttrSet.index(Attr)] != AttrVal:                    Flag = 'NotAllSame'            if Flag == 'AllSame':                DataSetNew.append(Sam)    # 找到数据集上的多数label    LabelCount = {}    for Sam in DataSetNew:        if Sam[-1] in LabelCount.keys():            LabelCount[Sam[-1]] += 1        else:            LabelCount[Sam[-1]] = 1    return max(LabelCount)def backword_cut(TreeCal,DataSet,TestSet,AttrSet):    import copy    # 获取当前树的非叶结点链条    KeysLines = get_nodelines(TreeCal, [], [])    for KeysLine in KeysLines:        # 获取list对应的DataSet,以最多类别替换最后一个结点,生成临时的新树        TreeTemp = copy.deepcopy(TreeCal)        TreeTempCh = TreeTemp[KeysLine[0]]        if len(KeysLine)>=3:            for m in range(1,len(KeysLine)-2):                TreeTempCh = TreeTempCh[KeysLine[m]]            TreeTempCh.pop(KeysLine[len(KeysLine)-2])            TreeTempCh[KeysLine[len(KeysLine)-2]] = cal_label( KeysLine, DataSet, AttrSet)        # 判断在TestSet上,新树的性能是否提升,提升就将旧树用新树替换        LabelNewTree = judge_label_set(TreeTemp, TestSet, AttrSet)        LabelOldTree = judge_label_set(TreeCal, TestSet, AttrSet)        LabelTrue = []        for Sam in TestSet:            LabelTrue.extend([Sam[-1]])        RNumNew,RNumOld = 0,0        for m in range(len(TestSet)):            if LabelTrue[m]==LabelNewTree[m]:                RNumNew += 1            if LabelTrue[m]==LabelOldTree[m]:                RNumOld += 1        if RNumNew>RNumOld:            TreeCal = TreeTemp    return TreeCal# 测试二import treePlotter as tpDataSet = [['青绿','蜷缩','浊响','清晰','凹陷','硬滑','好瓜'],           ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],           ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],            ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],            ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],            ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],            ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],            ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],            ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],            ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],]TestSet = [['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','好瓜'],            ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],            ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],           ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],            ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],            ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],            ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],]AttrSet = ['色泽','根蒂','敲声','纹理','脐部','触感']AttrValSet = [find_all_val(DataSet,m) for m in range(len(AttrSet))]TreeCal = create_tree(DataSet,AttrSet,AttrValSet)TreeACut = backword_cut(TreeCal,DataSet,TestSet,AttrSet)print('\n','DataSet = \n',DataSet,'\n')print('\n','TestSet = \n',TestSet,'\n')print('\n','AttrSet = \n',AttrSet,'\n')print('\n','AttrValSet = \n',AttrValSet,'\n')print('\n','TreeCal = \n',TreeCal,'\n')print('\n','TreeACut = \n',TreeACut,'\n')tp.createPlot(TreeCal)tp.createPlot(TreeACut)import matplotlib.pyplot as plt# 显示中文,需要在绘图函数中加上fontproperties=ChFontfrom pylab import *ChFont = matplotlib.font_manager.FontProperties(fname='C:\Windows\Fonts\STFANGSO.TTF')decisionNode = dict(boxstyle="sawtooth", fc="0.8")leafNode = dict(boxstyle="round4", fc="0.8")arrow_args = dict(arrowstyle="<-")def getNumLeafs(myTree):    numLeafs = 0    firstStr = list(myTree.keys())[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes            numLeafs += getNumLeafs(secondDict[key])        else:            numLeafs += 1    return numLeafsdef getTreeDepth(myTree):    maxDepth = 0    firstStr = list(myTree.keys())[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes            thisDepth = 1 + getTreeDepth(secondDict[key])        else:            thisDepth = 1        if thisDepth > maxDepth: maxDepth = thisDepth    return maxDepthdef plotNode(nodeTxt, centerPt, parentPt, nodeType):    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',                            xytext=centerPt, textcoords='axes fraction',                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args,fontproperties=ChFont,fontsize=12)def plotMidText(cntrPt, parentPt, txtString):    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]    yMid = (parentPt[1] - cntrPt[1]) / 1.9 + cntrPt[1]    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties=ChFont,fontsize=12)def plotTree(myTree, parentPt, nodeTxt):  # if the first key tells you what feat was split on    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree    depth = getTreeDepth(myTree)    firstStr = list(myTree.keys())[0]  # the text label for this node should be this    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)    plotMidText(cntrPt, parentPt, nodeTxt)    plotNode(firstStr, cntrPt, parentPt, decisionNode)    secondDict = myTree[firstStr]    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    for key in secondDict.keys():        if type(secondDict[                    key]).__name__ == 'dict':  # test to see if the nodes are dictonaires, if not they are leaf nodes            plotTree(secondDict[key], cntrPt, str(key))  # recursion        else:  # it's a leaf node print the leaf node            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# if you do get a dictonary you know it's a tree, and the first element will be another dictdef createPlot(inTree):    fig = plt.figure(1, facecolor='white')    fig.clf()    axprops = dict(xticks=[], yticks=[])    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks    # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses    plotTree.totalW = float(getNumLeafs(inTree))    plotTree.totalD = float(getTreeDepth(inTree))    plotTree.xOff = -0.5 / plotTree.totalW;    plotTree.yOff = 1.0;    plotTree(inTree, (0.5, 1.0), '')    plt.show()


原创粉丝点击