ID3决策树的算法原理与python实现
来源:互联网 发布:mac解压rar百度云 编辑:程序博客网 时间:2024/05/22 13:11
1. 引言
决策树从本质上是从训练数据集上训练处一组分类规则,完全依据训练数据,所得规则容易发生过拟合,这也是决策树的缺点,不过可以通过决策树的剪枝,来提高决策树的泛化能力。由此,决策树的创建可包括三部分:特征选择、决策树的生成以及决策树的剪枝;决策树的应用包括:分类、回归以及特征选择。
决策树最经典的算法包括:ID3、C4.5以及CART算法,ID3与C4.5算法相似,C4.5在特征选择时选用的信息准则是信息增益比,而ID3用的是信息增益;因为信息增益偏向于选择具有较多可能取值的特征(例如,某一特征具有5个可能取值,其信息增益会比具有2个特征取值的信息增益大)。
2.主要内容
- 基于信息论的特征选择(python实现信息增益的计算)
- 决策树的生成
- ID3算法的python实现。
3. 基于信息论的特征选择
注意:熵表示随机变量的不确定性,熵值越大表示随机变量含有的信息越少,变量的不确定性越大。
1) 香侬定义一个数据的信息可按下式计算 (此处是以2为底的对数)
2)熵表示一个数据集合信息的期望,可按下式计算:(该式不理解,可想象下,求变量期望的公式,
3)特征
上式中,设训练数据集为
4)python实现
def calcShannonEnt(dataset):#计算熵 numSamples = len(dataset) labelCounts = {} for allFeatureVector in dataset: currentLabel = allFeatureVector[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 entropy = 0.0 for key in labelCounts: property = float(labelCounts[key])/numSamples entropy -= property * log(property,2) return entropydef BestFeatToGetSubdataset(dataset): #下边这句实现:除去最后一列类别标签列剩余的列数即为特征个数 numFeature = len(dataset[0]) - 1 baseEntropy = calcShannonEnt(dataset) bestInfoGain = 0.0; bestFeature = -1 for i in range(numFeature):#i表示该函数传入的数据集中每个特征 # 下边这句实现抽取特征i在数据集中的所有取值 feat_i_values = [example[i] for example in dataset] uniqueValues = set(feat_i_values) feat_i_entropy = 0.0 for value in uniqueValues: subDataset = getSubDataset(dataset,i,value) #下边这句计算pi,实现计算信息增益最大的特征 prob_i = len(subDataset)/float(len(dataset)) feat_i_entropy += prob_i * calcShannonEnt(subDataset) infoGain_i = baseEntropy - feat_i_entropy if (infoGain_i > bestInfoGain): bestInfoGain = infoGain_i bestFeature = i return bestFeature
4.决策树生成
决策树生成可用下边的流程图表示:
5. ID3算法python实现代码
# -*- coding: utf-8 -*-from math import logimport operatorimport pickle'''输入:原始数据集、子数据集(最后一列为类别标签,其他为特征列)功能:计算原始数据集、子数据集(某一特征取值下对应的数据集)的香农熵输出:float型数值(数据集的熵值)'''def calcShannonEnt(dataset): numSamples = len(dataset) labelCounts = {} for allFeatureVector in dataset: currentLabel = allFeatureVector[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 entropy = 0.0 for key in labelCounts: property = float(labelCounts[key])/numSamples entropy -= property * log(property,2) return entropy'''输入:无功能:封装原始数据集输出:数据集、特征标签''' def creatDataSet(): dataset = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']] labels = ['no surfacing','flippers'] return dataset,labels'''输入:数据集、数据集中的某一特征所在列的索引、该特征某一可能取值(例如,(原始数据集、0,1 ))功能:取出在该特征取值下的子数据集(子集不包含该特征)输出:子数据集'''def getSubDataset(dataset,colIndex,value): subDataset = [] #用于存储子数据集 for rowVector in dataset: if rowVector[colIndex] == value: #下边两句实现抽取除第colIndex列特征的其他特征取值 subRowVector = rowVector[:colIndex] subRowVector.extend(rowVector[colIndex+1:]) #将抽取的特征行添加到特征子数据集中 subDataset.append(subRowVector) return subDataset'''输入:数据集功能:选择最优的特征,以便得到最优的子数据集(可简单的理解为特征在决策树中的先后顺序)输出:最优特征在数据集中的列索引'''def BestFeatToGetSubdataset(dataset): #下边这句实现:除去最后一列类别标签列剩余的列数即为特征个数 numFeature = len(dataset[0]) - 1 baseEntropy = calcShannonEnt(dataset) bestInfoGain = 0.0; bestFeature = -1 for i in range(numFeature):#i表示该函数传入的数据集中每个特征 # 下边这句实现抽取特征i在数据集中的所有取值 feat_i_values = [example[i] for example in dataset] uniqueValues = set(feat_i_values) feat_i_entropy = 0.0 for value in uniqueValues: subDataset = getSubDataset(dataset,i,value) #下边这句计算pi prob_i = len(subDataset)/float(len(dataset)) feat_i_entropy += prob_i * calcShannonEnt(subDataset) infoGain_i = baseEntropy - feat_i_entropy if (infoGain_i > bestInfoGain): bestInfoGain = infoGain_i bestFeature = i return bestFeature'''输入:子数据集的类别标签列功能:找出该数据集个数最多的类别输出:子数据集中个数最多的类别标签''' def mostClass(ClassList): classCount = {} for class_i in ClassList: if class_i not in classCount.keys(): classCount[class_i] = 0 classCount[class_i] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse = True) return sortedClassCount[0][0]'''输入:数据集,特征标签功能:创建决策树(直观的理解就是利用上述函数创建一个树形结构)输出:决策树(用嵌套的字典表示)''' def creatTree(dataset,labels): classList = [example[-1] for example in dataset] #判断传入的dataset中是否只有一种类别,是,返回该类别 if classList.count(classList[0]) == len(classList): return classList[0] #判断是否遍历完所有的特征,是,返回个数最多的类别 if len(dataset[0]) == 1: return mostClass(classList) #找出最好的特征划分数据集 bestFeat = BestFeatToGetSubdataset(dataset) #找出最好特征对应的标签 bestFeatLabel = labels[bestFeat] #搭建树结构 myTree = {bestFeatLabel:{}} del (labels[bestFeat]) #抽取最好特征的可能取值集合 bestFeatValues = [example[bestFeat] for example in dataset] uniqueBestFeatValues = set(bestFeatValues) for value in uniqueBestFeatValues: #取出在该最好特征的value取值下的子数据集和子标签列表 subDataset = getSubDataset(dataset,bestFeat,value) subLabels = labels[:] #递归创建子树 myTree[bestFeatLabel][value] = creatTree(subDataset,subLabels) return myTree'''输入:测试特征数据功能:调用训练决策树对测试数据打上类别标签输出:测试特征数据所属类别''' def classify(inputTree,featlabels,testFeatValue): firstStr = inputTree.keys()[0] secondDict = inputTree[firstStr] featIndex = featlabels.index(firstStr) for firstStr_value in secondDict.keys(): if testFeatValue[featIndex] == firstStr_value: if type(secondDict[firstStr_value]).__name__ == 'dict': classLabel = classify(secondDict[firstStr_value],featlabels,testFeatValue) else: classLabel = secondDict[firstStr_value] return classLabel '''输入:训练树,存储的文件名功能:训练树的存储输出:'''def storeTree(trainTree,filename): fw = open(filename,'w') pickle.dump(trainTree,fw) fw.close()def grabTree(filename): fr = open(filename) return pickle.load(fr)if __name__ == '__main__': dataset,labels = creatDataSet() storelabels = labels[:]#复制label trainTree = creatTree(dataset,labels) classlabel = classify(trainTree,storelabels,[0,1]) print classlabel
运行结果:
In [1]:runfile('E:/python/ml/dtrees/trees.py', wdir='E:/python/ml/dtrees')noIn [2]:
参考文献:
统计学习方法,李航
machine learning in action 中文版
- ID3决策树的算法原理与python实现
- 决策树分类ID3算法的Python实现
- 决策树ID3算法的python实现
- 分类算法-----决策树(ID3)算法原理和Python实现
- python实现决策树ID3算法
- Python实现ID3算法决策树
- Python实现决策树算法ID3
- 决策树ID3 算法python实现
- 决策树ID3的Python实现
- 机器学习之决策树(ID3)算法与Python实现
- ID3决策树原理分析及python实现
- 决策树算法原理及JAVA实现(ID3)
- ID3决策树算法原理及C++实现
- 决策树ID3算法原理
- 决策树的ID3算法实现(Python版)
- 【机器学习】决策树-ID3算法的Python实现
- 机器学习算法的Python实现 (2):ID3决策树
- 决策树之ID3算法实现(python)
- git代码回滚
- kotlin之新手入门(1)
- java使用ssh调用shell命令获取KVM数据(KVM需要通过libvirt管理)
- 教你如何用信用卡取现免手续费!
- windows时区数据
- ID3决策树的算法原理与python实现
- WebView进阶(一) :Android WebView与JS互相调用
- window对象方法
- jdk-AbstractQueuedSynchronizer(三)
- linux服务器安全优化设置
- react native 适配机顶盒、智能电视 遥控器解决焦点问题
- Python3.X之面向对象高级编程笔记
- Java中IO流知识总结
- 带罗盘按扭(八个方向按扭)的摄像头云台控件