id3算法(python代码)
来源:互联网 发布:mysql 创建覆盖索引 编辑:程序博客网 时间:2024/06/08 09:33
1. 该python实现没有考虑到overfitting。
- # coding=utf-8
- from numpy import *
- from math import log
- #下面的函数用来计算香农熵 H=sum(-p(xi)log(p(xi))) 其中xi指的是每种类别所占的比例
- def calcShannonEnt(dataSet):
- numEntries=len(dataSet)#数据集的行数
- labelCounts={}#数据集的类别标签
- for featVec in dataSet:
- currentLabel=featVec[-1]#读取数据集的类别
- if currentLabel not in labelCounts.keys():
- labelCounts[currentLabel]=0
- labelCounts[currentLabel]+=1
- shannonEnt=0.0#香农熵
- for key in labelCounts:
- prob=float(labelCounts[key])/numEntries
- shannonEnt-=prob*log(prob,2)
- return shannonEnt#返回香农熵
- #创建数据集,返回数据集合和标签列表,该标签是人工标签,仅仅是表明真实类别的名字而已
- def createDataSet():
- dataSet=[
- [1,1,'yes'],
- [1,1,'yes'],
- [1,0,'no'],
- [0,1,'no'],
- [0,1,'no']
- ]
- labels=['LabelOne','LableTwo']
- return dataSet,labels
- #划分数据集,axis是特征值,value是该特征值对应的value
- #将dataSet中特征值等于value的样本点筛选出来
- def splitDataSet(dataSet,axis,value):
- retDataSet=[]
- for featVec in dataSet:
- if featVec[axis]==value:
- reducedFeatVec=featVec[:axis]
- reducedFeatVec=featVec[axis+1:]
- retDataSet.append(reducedFeatVec)
- return retDataSet
- #遍历数据集找到最好的数据集划分方式
- #H=sum(-p(xi)log(p(xi)))
- def chooseBestFeatureToSplit(dataSet):
- numFeatures=len(dataSet[0])-1 #numFeature是样本xi的特征数目,请注意最后一个是类别标签
- baseEntroy=calcShannonEnt(dataSet)#计算原始数据集的香农熵
- bestInfoGain=0.0;#最大信息增益
- bestFeature=-1#采用哪个特征分裂
- #从现有的所有特征之中找到最合适的分裂,是的信息的增益最大
- for i in range(numFeatures):
- featList=[example[i] for example in dataSet]#找到第i个feature的所有可能取值 ,有可能有重复的
- uniqueVals=set(featList)#去重
- newEntroy=0.0
- for value in uniqueVals:
- subDataSet=splitDataSet(dataSet,i,value)
- #Gain(S,A)=Entroy(S)-Sum((Sv/S)*Entroy(Sv)) 信息增益
- prob=len(subDataSet)/float(len(dataSet))#Sv/S
- newEntroy+=prob*calcShannonEnt(subDataSet)
- infoGain=baseEntroy-newEntroy#计算出信息增益
- if (infoGain>bestInfoGain):#若信息增益大于当前最大的信息增益则需要记录最大的信息增益以及其对应的分裂Feature
- bestInfoGain=infoGain
- bestFeature=i
- return bestFeature#返回按照哪个feature进行Split
- #当特征耗费完的时候,dataSet中仍然的分类标签依然不纯,那么就要用少数服从多数的原则来决定分类了
- def majorityCnt(classList):
- classCount={}#每种类别包含的个数
- for vote in classList:
- if vote not in classCount.keys():
- classCount[vote]=0
- classCount[vote]+=1#统计每种类别的个数
- sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
- return sortedClassCount[0][0]
- #创建ID3树
- #首先检测当前数据集是否是纯净的,若是纯净的直接返回类标签
- #再次检测当前数据集的feature是否已经被完全耗费完毕了,若耗费完毕了,直接返回
- #之后构造树形结构,递归构造树
- def createTree(dataSet,labels):
- classList=[example[-1] for example in dataSet]#获取类别标签
- if classList.count(classList[0])==len(classList):#若当前数据集的所有标签全部都一样了
- return classList[0]
- if len(dataSet[0])==1:#特征完全耗费完毕了
- return majorityCnt(dataSet)
- #获取最好的分类feature
- bestFeat=chooseBestFeatureToSplit(dataSet)
- bestFeatLabel=labels[bestFeat]#获取分类属性的具体名字
- myTree={bestFeatLabel:{}}#创建树形结构
- del(labels[bestFeat])#删除分类属性 每次都要耗费掉一个属性
- #获取最好分类Feature所对应的所有值
- featValues=[example[bestFeat] for example in dataSet]
- uniqueVals=set(featValues)
- for value in uniqueVals:
- subLabels=labels[:]#子标签
- myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
- return myTree
- #测试ID3分类效果 inputTree是决策树本身
- #featLabels是人工表示的类别类型
- #testVec是测试向量
- def classify(inputTree,featLabels,testVec):
- firstStr=inputTree.keys()[0]#第一个分类的名字
- secondDict=inputTree[firstStr]#第一个分类名字对应的值{或者是集合或者是标签}
- featIndex=featLabels.index(firstStr)#当前分类对应的feature的下标
- for key in secondDict.keys():
- if testVec[featIndex]==key:#若当前项是集合 递归下去
- if type(secondDict[key]).__name__=='dict':
- classLabel=classify(secondDict[key],featLabels,testVec)
- else:#若当前项是标签
- classLabel=secondDict[key]
- return classLabel
- myDat,labels=createDataSet()
- t=createTree(myDat,labels)
- print t
- labels=['LabelOne','LableTwo']
- print classify(t,labels,[1,0])
- print classify(t,labels,[1,1])
结果>>> ================================ RESTART ================================
>>>
{'LabelOne': {0: 'no', 1: {'LableTwo': {0: 'no', 1: 'yes'}}}}
no
>>> ================================ RESTART ================================
>>>
{'LabelOne': {0: 'no', 1: {'LableTwo': {0: 'no', 1: 'yes'}}}}
no
yes
0 0
- id3算法(python代码)
- id3算法(python代码)
- 决策树ID3代码(Python)
- python实现决策树(ID3算法)
- Python 决策树算法(ID3 & C4.5)
- ID3决策树算法(python实现)
- 决策树(ID3算法)Python实现
- python ID3算法
- 机器学习实战_初识决策树(ID3)算法_理解其python代码(二)
- 决策树ID3算法python实现代码及详细注释
- 分类算法-----决策树(ID3)算法原理和Python实现
- ID3算法简单实例(代码)
- ID3算法的python实现
- python实现决策树ID3算法
- Python实现ID3算法决策树
- Python实现决策树算法ID3
- 决策树ID3 算法python实现
- 决策树(ID3)python
- poj 2909
- 语言差别
- 【HNOI2013】数列
- 数据挖掘笔记-特征选择-算法实现-1
- 深入分析JavaWeb Item17 -- JavaBean组件
- id3算法(python代码)
- 多个Flume合并一个channel上传文件到Hdfs
- C语言
- 决策树
- 【树】二叉树的各种操作
- spring支持JCP的JSR330规范,使用javax.inject
- Vector
- 糯米团—重制“iPhone团购信息客户端”(零)源代码与跳的那些坑和思考
- wamp集成环境下的配置问题 ----显示字符集的问题