A Decision Tree Model for Classification Problems
The code below generates Figure 4.7 (the decision tree with post-pruning) from Zhou Zhihua's *Machine Learning*.
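The splitting criterion is information gain, as in Chapter 4 of the book: for a sample set $D$ in which class $k$ has proportion $p_k$, and a discrete attribute $a$ whose $v$-th value ($v = 1, \dots, V$) selects the subset $D^v$,

$$\mathrm{Ent}(D) = -\sum_{k} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v).$$

In the listing, `cal_shan_ent` computes $\mathrm{Ent}(D)$ and `maxgain_cal` returns the index of the attribute that maximizes $\mathrm{Gain}(D, a)$.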
First, the tree construction and post-pruning routines:

```python
import copy
import math

# Decision tree with post-pruning.
# Attribute values are taken from the full initial training set, not from the
# current subset.
# DataSet is a 2-D list: each row is one sample, the last column is its label.
# AttrSet is a 1-D list of attribute names.
# AttrValSet is a 2-D list: row m holds all values of attribute m.


def cal_shan_ent(DataSet):
    # Shannon entropy of the label column.
    LabelCount = {}
    for Sample in DataSet:
        LabelCount[Sample[-1]] = LabelCount.get(Sample[-1], 0) + 1
    EntCal = 0
    for Key in LabelCount:
        Prob = LabelCount[Key] / len(DataSet)
        EntCal -= Prob * math.log2(Prob)
    return EntCal


def spl_data(DataSet, Axis, Value):
    # Samples with value Value on attribute Axis, with that column removed.
    DataSetSpl = []
    for Sam in DataSet:
        if Sam[Axis] == Value:
            SamRedu = Sam[0:Axis]
            SamRedu.extend(Sam[Axis + 1:])
            DataSetSpl.append(SamRedu)
    return DataSetSpl


def find_all_val(DataSet, BestAttrIndex):
    # All distinct values of one attribute, in order of first appearance.
    AttrVal = []
    for Sample in DataSet:
        if Sample[BestAttrIndex] not in AttrVal:
            AttrVal.append(Sample[BestAttrIndex])
    return AttrVal


def maxgain_cal(DataSet):
    # Index of the attribute with the largest information gain.
    GainCount = []
    NumSample = len(DataSet)
    NumAttr = len(DataSet[0]) - 1
    for m in range(NumAttr):
        GainCal = cal_shan_ent(DataSet)
        for Val in find_all_val(DataSet, m):
            Data_Val = spl_data(DataSet, m, Val)
            GainCal -= len(Data_Val) / NumSample * cal_shan_ent(Data_Val)
        GainCount.append(GainCal)
    return GainCount.index(max(GainCount))


def major_label(DataSet):
    # Most frequent label; ties go to the label encountered first.
    LabelCount = {}
    for Sample in DataSet:
        LabelCount[Sample[-1]] = LabelCount.get(Sample[-1], 0) + 1
    return max(LabelCount, key=LabelCount.get)


def judge_attr_same(DataSet):
    # True if every sample has identical values on all attributes.
    NumCompare = len(DataSet[0]) - 1
    for n in range(len(DataSet)):
        if DataSet[0][0:NumCompare] != DataSet[n][0:NumCompare]:
            return False
    return True


def judge_label_same(DataSet):
    # True if every sample carries the same label.
    for n in range(len(DataSet)):
        if DataSet[0][-1] != DataSet[n][-1]:
            return False
    return True


def create_tree(DataSet, AttrSet, AttrValSet):
    # If all samples share one label, further splits gain nothing:
    # return that label as a leaf.
    if judge_label_same(DataSet):
        return DataSet[0][-1]
    # If the attribute set is empty, every attribute has been used up along
    # this branch: return the majority label as a leaf.
    if AttrSet == []:
        return major_label(DataSet)
    # If all samples agree on every attribute, further splits would only
    # overfit (and leave some branches empty): return the majority label.
    if judge_attr_same(DataSet):
        return major_label(DataSet)
    # Otherwise, recurse: pick the attribute with the largest information gain.
    BestAttrIndex = maxgain_cal(DataSet)
    BestAttr = AttrSet[BestAttrIndex]
    Tree = {BestAttr: {}}
    # Look up all values of this attribute in the *initial* value sets.
    BestAttrVal = AttrValSet[BestAttrIndex]
    # For every value: split the samples, shrink the attribute sets, recurse.
    for Val in BestAttrVal:
        DataSetDiv = spl_data(DataSet, BestAttrIndex, Val)
        if DataSetDiv == []:
            # No training sample reaches this branch: fall back to the
            # parent's majority label.
            Tree[BestAttr][Val] = major_label(DataSet)
        else:
            AttrSetShort = AttrSet[0:BestAttrIndex] + AttrSet[BestAttrIndex + 1:]
            AttrValSetShort = AttrValSet[0:BestAttrIndex] + AttrValSet[BestAttrIndex + 1:]
            Tree[BestAttr][Val] = create_tree(DataSetDiv, AttrSetShort, AttrValSetShort)
    return Tree


def judge_label_sam(TreeCal, TestSam, AttrSet):
    # Classify one sample by walking the nested-dict tree down to a leaf label.
    AttrJudged = list(TreeCal.keys())[0]
    IndexJudged = AttrSet.index(AttrJudged)
    SubTree = TreeCal[AttrJudged][TestSam[IndexJudged]]
    if not isinstance(SubTree, dict):
        return SubTree
    return judge_label_sam(SubTree, TestSam, AttrSet)


def judge_label_set(TreeCal, TestSet, AttrSet):
    # Classify every sample in TestSet.
    return [judge_label_sam(TreeCal, TestSam, AttrSet) for TestSam in TestSet]


def get_nodelines(TreeCal, KeysLines, KeysLineCurr):
    # Collect, for every internal node, the chain of keys
    # [attr, value, attr, value, ..., attr] leading to it from the root.
    AttrRoot = list(TreeCal.keys())[0]
    AttrValRoot = list(TreeCal[AttrRoot].keys())
    # If every child of this node is a leaf, record the chain and stop.
    if all(not isinstance(TreeCal[AttrRoot][AttrVal], dict) for AttrVal in AttrValRoot):
        KeysLineCurr.append(AttrRoot)
        KeysLines.append(KeysLineCurr)
        return KeysLines
    # Otherwise record the chain for this node, then recurse into dict children.
    KeysLineCurr.append(AttrRoot)
    KeysLines.append(KeysLineCurr)
    for AttrVal in AttrValRoot:
        if isinstance(TreeCal[AttrRoot][AttrVal], dict):
            KeysLineTemp = copy.deepcopy(KeysLineCurr)
            KeysLineTemp.append(AttrVal)
            KeysLines = get_nodelines(TreeCal[AttrRoot][AttrVal], KeysLines, KeysLineTemp)
    # Sort so that longer chains (deeper nodes) are considered first.
    KeysLines.sort(key=len, reverse=True)
    return KeysLines


def cal_label(KeysLine, DataSet, AttrSet):
    # Majority label of the training samples that reach the node at KeysLine.
    JudgeNum = len(KeysLine) // 2  # number of (attr, value) pairs above the node
    DataSetNew = []
    for Sam in DataSet:
        if all(Sam[AttrSet.index(KeysLine[2 * m])] == KeysLine[2 * m + 1]
               for m in range(JudgeNum)):
            DataSetNew.append(Sam)
    LabelCount = {}
    for Sam in DataSetNew:
        LabelCount[Sam[-1]] = LabelCount.get(Sam[-1], 0) + 1
    return max(LabelCount, key=LabelCount.get)


def backword_cut(TreeCal, DataSet, TestSet, AttrSet):
    # Post-pruning: try to replace each internal node (deepest first) with the
    # majority label of the training samples that reach it, and keep the
    # replacement whenever it classifies TestSet strictly better.
    KeysLines = get_nodelines(TreeCal, [], [])
    for KeysLine in KeysLines:
        if len(KeysLine) == 1:
            # A chain of length 1 is the root; collapsing the whole tree to a
            # single label is not representable in this nested-dict format.
            continue
        # Build a candidate tree with the node replaced by a majority-label leaf.
        TreeTemp = copy.deepcopy(TreeCal)
        TreeTempCh = TreeTemp[KeysLine[0]]
        for m in range(1, len(KeysLine) - 2):
            TreeTempCh = TreeTempCh[KeysLine[m]]
        TreeTempCh[KeysLine[-2]] = cal_label(KeysLine, DataSet, AttrSet)
        # Keep the candidate if its accuracy on TestSet improves.
        LabelNewTree = judge_label_set(TreeTemp, TestSet, AttrSet)
        LabelOldTree = judge_label_set(TreeCal, TestSet, AttrSet)
        LabelTrue = [Sam[-1] for Sam in TestSet]
        RNumNew = sum(a == b for a, b in zip(LabelTrue, LabelNewTree))
        RNumOld = sum(a == b for a, b in zip(LabelTrue, LabelOldTree))
        if RNumNew > RNumOld:
            TreeCal = TreeTemp
    return TreeCal
```
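For reference, `create_tree` returns the tree as nested dicts keyed alternately by attribute name and attribute value, and `judge_label_sam` walks that structure until it hits a string (a leaf label). A minimal sketch with a hypothetical toy tree (`ToyTree`, `ToyAttrSet`, and `Sample` are made up for illustration):

```python
# A hypothetical toy tree in the nested-dict format used above:
# internal nodes are {attribute: {value: subtree_or_label}}.
ToyTree = {'纹理': {'清晰': '好瓜',
                    '稍糊': {'触感': {'硬滑': '坏瓜', '软粘': '好瓜'}},
                    '模糊': '坏瓜'}}
ToyAttrSet = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']

# judge_label_sam reads the sample's value for the root attribute (纹理 = 稍糊),
# descends into the 触感 subtree, and stops at the leaf label.
Sample = ['青绿', '蜷缩', '浊响', '稍糊', '凹陷', '硬滑']
print(judge_label_sam(ToyTree, Sample, ToyAttrSet))  # -> 坏瓜
```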
The test below builds the tree on a 10-sample watermelon training set, prunes it against a 7-sample validation set, and plots both trees (`treePlotter` is the module listed afterwards):

```python
# Test 2: watermelon data, split into a training set and a validation set.
import treePlotter as tp

DataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
           ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
           ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
           ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
           ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
           ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
           ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
           ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
           ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
           ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']]
TestSet = [['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
           ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
           ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
           ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
           ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
           ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
           ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜']]
AttrSet = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
AttrValSet = [find_all_val(DataSet, m) for m in range(len(AttrSet))]

TreeCal = create_tree(DataSet, AttrSet, AttrValSet)          # unpruned tree
TreeACut = backword_cut(TreeCal, DataSet, TestSet, AttrSet)  # post-pruned tree

print('\n', 'DataSet = \n', DataSet, '\n')
print('\n', 'TestSet = \n', TestSet, '\n')
print('\n', 'AttrSet = \n', AttrSet, '\n')
print('\n', 'AttrValSet = \n', AttrValSet, '\n')
print('\n', 'TreeCal = \n', TreeCal, '\n')
print('\n', 'TreeACut = \n', TreeACut, '\n')

tp.createPlot(TreeCal)
tp.createPlot(TreeACut)
```
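To quantify what pruning buys on the validation split, the two trees can be scored directly with the functions above. A minimal sketch (`cal_accuracy` is a hypothetical helper, not part of the original script):

```python
def cal_accuracy(Tree, TestSet, AttrSet):
    # Fraction of validation samples whose predicted label matches the truth.
    Pred = judge_label_set(Tree, TestSet, AttrSet)
    return sum(p == s[-1] for p, s in zip(Pred, TestSet)) / len(TestSet)

print('accuracy before pruning:', cal_accuracy(TreeCal, TestSet, AttrSet))
print('accuracy after pruning: ', cal_accuracy(TreeACut, TestSet, AttrSet))
```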
dict(boxstyle="round4", fc="0.8")arrow_args = dict(arrowstyle="<-")def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[ key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafsdef getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[ key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepthdef plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args,fontproperties=ChFont,fontsize=12)def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1]) / 1.9 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties=ChFont,fontsize=12)def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) # this determines the x width of this tree depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] # the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[ key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key], cntrPt, str(key)) # recursion else: # it's a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# if you do get a dictonary you know it's a tree, and the first element will be another dictdef createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5, 1.0), '') plt.show()