决策树的算法描述和源码分析
来源:互联网 发布:人工智能伏羲觉醒种子 编辑:程序博客网 时间:2024/06/05 04:13
前言:
决策树是一种分类算法,简单而有效。它与KNN不同,KNN可以简单地得到结果,但是无法理解数据本身的含义。决策树的优势在于可以以树状的形式理解数据。现在我们希望可以判断邮件箱中的邮件属性,即是必须要处理的,还是属于垃圾邮件。决策树的做法如图一所示
我们将根据邮件数据中的两个属性进行判断,首先判断邮件域名地址,如果结果为是,则此封邮件属于无聊时需要阅读的邮件,如果结果为否,则继续判断数据的第二个属性即是否包含单词”曲棍球”,若结果为是,则此封邮件为需要及时处理的朋友邮件,为否则属于无需阅读的垃圾邮件。
可以看到得到的决策树简洁明了,如果配合adaboost、以及级联使用,可以更进一步地改善分类效果,决策树也是处理数据中非常常用的一种算法。决策树的算法有很多种,这里主要讲的是ID3算法,接下来将是算法描述和源码分析部分。
算法描述
划分数据集的大原则是:将无序的数据集变得更加有序。基于此原则,我们首先得确定何为有序数据,何为无序数据。
这里引入一个概念,香农熵或者简称为熵:它是由天才科学家香农提出的,表示随机变量不确定的度量。
计算公式如下:
其中它表示了待分类的事物可能划分在多个分类中,则信息xi的定义即为l(xi),其中p(xi)是选择该分类的概率。很显然可以看到,熵即为信息的期望。当各分类的概率相同时,熵最大,当只有一种可能时,熵最小,此时熵为0。
根据我们划分数据的原则,我们希望选择特征的时候能把数据集的熵尽可能降低,进而使数据变得有序,更好地处理。根据这种自然而然的想法, 就有了下个定义,信息增益:
看起来经验条件熵非常复杂, 其实很简单,就是根据特征将一个大的数据集划分为几个小的数据集,然后分别计算小数据集的熵,然后取均值。
ID3算法的核心是在决策树各个结点上应用信息增益准则选择特征,递归地构建决策树。具体方法是:从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点:再对子结点递归地调用以上方法,构建决策树,最后所得到的决策树如上图一所示。
得到决策树后就可以对新来的数据进行分类了,具体做法是,从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
源码分析
对一个算法描述,难免由于不同人的不同理解产生误差,所以我个人认为了解一个算法的最好途径是通过代码——这种不会产生误差的方式。因为我已经在代码中加以了注释,就不再对各个部分进行讲解了,如果有问题欢迎讨论O(∩_∩)O~。
matlab代码
function readtextreadname='F:\机器学习实战\机器学习实战\源码+数据\Ch03\lenses.txt';text=textread(readname,'%s','delimiter','');%一行行地读取数据num=zeros(size(text,1),5);%开辟空间存储特征for i=1:size(text,1)text{i}=regexp(text{i}, '\w+','match');%采用正则表达式解析字符串if(size(text{i},2))==6 text{i}{5}=strcat(text{i}{5},text{i}{6});%因为最后一个特征有no lenses、soft、hard %为了方便处理将第一种情况的两个字符串合在一起 text{i}(6)=[];end% 因为matlab没有类似python中的set操作,无法删除字符数组中重复出现的字符串,所以只好事先将字符转为数字switch text{i}{1} case 'young' num(i,1)=1; case 'pre' num(i,1)=2; case 'presbyopic' num(i,1)=3;endswitch text{i}{2} case 'myope' num(i,2)=1; case 'hyper' num(i,2)=2;endswitch text{i}{3} case 'no' num(i,3)=1; case 'yes' num(i,3)=2;endswitch text{i}{4} case 'reduced' num(i,4)=1; case 'normal' num(i,4)=2;endswitch text{i}{5} case 'nolenses' num(i,5)=1; case 'hard' num(i,5)=2; case 'soft' num(i,5)=3;endendlabel={'age','prescript','astigmatic','tearRate'};mytree=CreateTree(num,label);function mytree=CreateTree(dataset,label)% function :构建决策树% Input :dataset :数据集% label:类别标签% Output :输出记录决策树结点的结构体mytree.FatherNode='root';mytree.property='stem';databest=ChooseBestFeatureToSplit(dataset);mytree.NodeName=label{databest};Feature=unique(dataset(:,databest));FatherNode=label{databest};label(databest)=[];for i=1:length(Feature) mytree=[CreateChildTree(splitDataSet(dataset,databest,Feature(i)),i,FatherNode,label);mytree];endfunction mytree=CreateChildTree(dataset,property,FatherNode,label)mytree.FatherNode=FatherNode;mytree.property=property;if length(unique(dataset(:,end)))==1 %如果剩下的所有样本皆为同一类则跳出 mytree.NodeName=unique(dataset(:,end)); return endif size(dataset,2)==1 %分割完所有的特征,则选取剩下样本中类别数最多的一类返回 mytree.NodeName=MajorClass(dataset); returnenddatabest=ChooseBestFeatureToSplit(dataset);mytree.NodeName=label{databest};Feature=unique(dataset(:,databest));FatherNode=label{databest};label(databest)=[];for i=1:length(Feature) mytree=[CreateChildTree(splitDataSet(dataset,databest,Feature(i)),i,FatherNode,label);mytree];endfunction class=MajorClass(classlist)% function : 选择类别中数量最多的类别返回% Input : 给定的类别列表%% Output :返回最后的类别 tabClasslist=tabulate(classlist);[~,index]=max(tabClasslist(:,3));class=tabClasslist(1,index);function BestFeature=ChooseBestFeatureToSplit(dataset)% funciton: 选择最好的特征来分割样本% Input : 一个表示样本的m*N+1矩阵,前N列为特征,最后一列为类别% % Output : 返回最好的特征索引BestFeature=0;BestEnt=1/eps;for i=1:size(dataset,2)-1 Feature=dataset(:,i); FeatureClass=unique(Feature); newEnt=0; for j=1:length(FeatureClass) subData=splitDataSet(dataset,i,FeatureClass(j)); prob=length(subData)/length(Feature); newEnt=newEnt+prob*CalShannonEnt(subData(:,end)); end if newEnt<BestEnt BestEnt=newEnt; BestFeature=i; endendfunction Ent=CalShannonEnt(datavector)% function: 计算香农熵% Input datavector: 代表类别的向量% % output Ent:返回计算的熵值% tabclass=tabulate(datavector);tabclass(:,3)=tabclass(:,3)/100;tabclass(tabclass(:,2)==0,:)=[];Ent=-sum(tabclass(:,3).*log2(tabclass(:,3)));function reData=splitDataSet(dataset,axis,value)% function: 根据给定的特征划分样本% Input dataset:给定的样本集,为一个n*p矩阵% axis :给定的划分特征% value :需符合的条件% Output reData :根据给定条件所得到的划分index=find(dataset(:,axis)==value);dataset(:,axis)=[];reData=dataset(index,:);
Python代码
'''Created on Oct 12, 2010Decision Tree Source Code for Machine Learning in Action Ch. 3@author: Peter Harrington'''from math import logimport operatordef createDataSet(): #创建新的数据集 dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] labels = ['no surfacing','flippers'] #change to discrete values return dataSet, labelsdef calcShannonEnt(dataSet): #计算香农熵 numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: #the the number of unique elements and their occurance currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries #统计每一个类出现的概率 shannonEnt -= prob * log(prob,2) #根据公式计算 return shannonEntdef splitDataSet(dataSet, axis, value): #划分数据集 retDataSet = [] for featVec in dataSet: if featVec[axis] == value:#把处于同一值下的所有样本取出 reducedFeatVec = featVec[:axis] #chop out axis used for splitting reducedFeatVec.extend(featVec[axis+1:]) retDataSet.append(reducedFeatVec) return retDataSetdef chooseBestFeatureToSplit(dataSet): #根据信息增益准则,选择最好的特征进行分割 numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0; bestFeature = -1 #计算以不同特征的不同值进行分割所得到的信息增益,选最大的那个进行分割 for i in range(numFeatures): #iterate over all the features featList = [example[i] for example in dataSet]#create a list of all the examples of this feature uniqueVals = set(featList) #get a set of unique values newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy if (infoGain > bestInfoGain): #compare this to the best gain so far bestInfoGain = infoGain #if better than current best, set to best bestFeature = i return bestFeature #returns an integerdef majorityCnt(classList): #如果特征筛选完毕,则选择所余样本中类别数最大的一个作为叶子节点的类别 classCount={} for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]def createTree(dataSet,labels): #构造字典树 classList = [example[-1] for example in dataSet] if classList.count(classList[0]) == len(classList): return classList[0]#stop splitting when all of the classes are equal if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet return majorityCnt(classList) bestFeat = chooseBestFeatureToSplit(dataSet)#得到最好的分类特征 bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) return myTree def classify(inputTree,featLabels,testVec): #进行分类 firstStr = inputTree.keys()[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) key = testVec[featIndex] valueOfFeat = secondDict[key] if isinstance(valueOfFeat, dict): #反复迭代,直到抵达叶子节点 classLabel = classify(valueOfFeat, featLabels, testVec) else: classLabel = valueOfFeat return classLabeldef GetData(): #获取text文本中的数据 fr=open('lenses.txt') lenses=[rows.strip().split('\t') for rows in fr] lenseslabel=['age','prescript','astigmatic','tearRate'] return lenses,lenseslabel'''def storeTree(inputTree,filename): import pickle fw = open(filename,'w') pickle.dump(inputTree,fw) fw.close()def grabTree(filename): import pickle fr = open(filename) return pickle.load(fr)'''myData,label=createDataSet()labeltemp=label[:]lenses,lenseslabel=GetData()tree=createTree(lenses,lenseslabel)print tree
包括数据集的完整链接,点这
- 决策树的算法描述和源码分析
- 分析决策树算法和逻辑回归算法的准确率问题
- Mahout决策树算法源码分析(1)
- Mahout决策树算法源码分析(2)
- Mahout决策树算法源码分析(3)
- Mahout决策树算法源码分析(4)
- Mahout决策树算法源码分析(2)
- 决策树理论、C4.5源码分析及AdaBoost算法的提升改造
- Mahout决策树算法源码分析(3-1)建树实战
- MS决策树分析算法
- 决策树ID3和C4.5算法Python实现源码
- 决策树ID3和C4.5算法Python实现源码
- 决策树ID3和C4.5算法Python实现源码
- 决策树的剪枝和CART算法
- Spark中决策树源码分析
- 数据结构与算法分析:C语言描述(pdf+源码+答案)
- 数据分析算法----1决策树
- 《机器学习实战》决策树(ID3算法)的分析与实现
- Java学习心得体会。
- ListView中包含Button情况下焦点事件的获取
- 多线程继承Thread和实现runnable的区别
- 欢迎使用CSDN-markdown编辑器
- Homebrew 终于进入1.x
- 决策树的算法描述和源码分析
- Proguard笔记
- 扩展BaseAdapter实现不存储列表项的ListView
- mysqlslap
- Hadoop0.20.2版本在Ubuntu下安装和配置
- 第四章 条件结构
- 万用表的使用
- win7无法安装msi解决办法
- 第五章 循环结构