机器学习实战---读书笔记: 第3章 决策树
来源:互联网 发布:印度的种姓制度 知乎 编辑:程序博客网 时间:2024/06/01 12:53
内容来源于书《机器学习实战》
# *-* coding: utf-8 *-* '''<<机器学习实战>> ---读书笔记: 第3章 决策树关键:1 决策树基础知识:决策树任务:理解数据中蕴含的知识,提取规则。应用:专家系统优点:复杂度不高,中间值缺失不敏感缺点:容易产生过拟合适用:数值型和标称型决策树构造过程:找到当前能够最大区分数据的特征,将数据划分;如果划分后的每一类数据都属于同一类,则停止划分并设置类标签;否则,对非同一类的数据集再次重复上述过程。长方形:判断模块,椭圆:终止模块,左右箭头:分支2信息增益:划分前后信息的变化。信息增益最高的特征就是最好选择,信息增益表示数据无序度的减少熵:信息的期望值,熵越大,混合数据越多。实际需要求得使得熵最小的划分特征符号Xi的信息为l(Xi) = -log2 P(Xi) , P(Xi)是选择该分类的概率熵: H = -P(Xi) * log2 P(Xi) i从1到n3选择最好的特征过程:选择最好的特征进行划分数据集。具体过程是:遍历每个特征,收集每个特征所有取值的集合,计算该特征每个取值对应的信息,累加后得到该特征的熵。如果该原始熵-当前熵的结果大于信息增益,更新信息增益,并记录该最好特征4创建决策树如果类别标签相同,直接返回,类别标签。否则,如果所有特征用完,选择次数最多的类别。 这里的类别应该是: yes 或者 no计算能够划分得到最大信息增益的特征,然后获取特征的所有取值,遍历每个取值,递归得对每个取值下的数据集进行划分。构建出: 当前特征对应的映射5决策树的存储构造决策树耗时,每次分类时调用已经构造好的决策树,使用pickle序列化对象。序列化对象:在磁盘保存 #必须以二进制形式保存 fw = open(fileName , "wb" ) # pickle.dump(obj , file, protocol) :件对象保存到文件中,pickle可以事先基本数据的序列和反序列化 pickle.dump(inputTree ,fw) fw.close() fr = open(fileName , "rb") # pickle.load(file):从文件中读取字符串,重构为原来的python对象 return pickle.load(fr)'''from math import logimport operatorimport matplotlib.pyplot as plt#计算给定数据集的香农熵def calcShannonEnt(dataSet): rows = len(dataSet) #统计每个类别出现的概率 labelToCount = dict() for data in dataSet: label = data[-1] if label in labelToCount: labelToCount[label] += 1 else: labelToCount[label] = 1 #计算香农熵: H = - P(Xi) * log2 P(Xi) result = 0.0 for label , count in labelToCount.items(): # //返回整数, /返回浮点数,一般用/ prob = count * 1.0 / rows result -= prob * log(prob , 2) return result#按照给定特征划分数据集,实际就是遍历,根据给定的列号,对应的列值,生成除该列以外的划分向量def splitDataSet(dataSet , columnNum , value): resultDatas = [] for data in dataSet: if value == data[columnNum]: front = data[ : columnNum] back = data[columnNum + 1 : ] front.extend(back) resultDatas.append(front) return resultDatasdef createDataSet(): dataSet = [ [1, 1, 'yes'] , [1, 1, 'yes'], [1, 0 , 'no'], [0, 1 , 'no'], [0, 1 , 'no'] ] labels = ['no surfacting' , 'flippers'] return dataSet , labelsdef calcShannonEnt_test(): dataSet , labels = createDataSet() result = calcShannonEnt(dataSet) print(result)'''选择最好的特征进行划分数据集。具体过程是:遍历每个特征,收集每个特征所有取值的集合,计算该特征每个取值对应的信息,累加后得到该特征的熵。如果该原始熵-当前熵的结果大于信息增益,更新信息增益,并记录该最好特征'''def chooseBestFeature(dataSet): rows = len(dataSet) featureNum = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFrature = -1 #遍历每个特征,对每个特征计算熵 for i in range(featureNum): features = [ temp[i] for temp in dataSet ] featureValues = set(features) #根据特征取值,划分数据集,计算划分后的数据集的熵 newEntropy = 0.0 for value in featureValues: subDatas = splitDataSet(dataSet , i , value) prob = float( len(subDatas) / rows ) newEntropy += prob * calcShannonEnt(subDatas) infoGain = baseEntropy - newEntropy if (infoGain > bestInfoGain): bestInfoGain = infoGain bestFrature = i return bestFraturedef chooseBestFeature_test(): myDat , labels = createDataSet() bestFeature = chooseBestFeature(myDat) print(bestFeature)#统计得到<类别, 出现次数>这样的映射,选择出出现次数最多的类别作为返回def majorityCount(classList): labelToCount = {} for vote in classList: if vote in labelToCount: labelToCount[vote] += 1 else: labelToCount[vote] = 1 sortedResult = sorted(labelToCount.items() , key=operator.itemgetter(1) , reversed=True) return sortedResult[0][0]'''创建决策树:如果类别标签相同,直接返回,类别标签。否则,如果所有特征用完,选择次数最多的类别。 这里的类别应该是: yes 或者 no计算能够划分得到最大信息增益的特征,然后获取特征的所有取值,遍历每个取值,递归得对每个取值下的数据集进行划分。构建出: 当前特征对应的映射'''def createTree(dataSet , labels): #所谓的类别信息就是划分后的几个类别,比如yes,no ;或者: 猫,狗,牛 等类别信息 ; 但是标签似乎和类别是相同的说法 classList = [ temp[-1] for temp in dataSet ] #如果只有一个类别,说明之前经过某个特征值划分后的数据集只有一个类别,直接返回该类别 if classList.count(classList[0]) == len(classList): return classList[0] #如果所有特征都用完,选择出现次数最多的类别 if 1 == len(classList[0]) : return majorityCount(classList) #选择能够带来最大信息增益的特征,并按照该特征值划分得到的子数据集 重复上述操作 bestFeature = chooseBestFeature(dataSet) features = [ temp[bestFeature] for temp in dataSet ] uniqueFratures = set(features) bestFeatureLabel = labels[bestFeature] decisionTree = {bestFeatureLabel : {} } #需要删除最优特征对应的标签 del labels[bestFeature] for value in uniqueFratures: subLabels = labels[ : ] subdatas = splitDataSet(dataSet , bestFeature , value) decisionTree[bestFeatureLabel][value] = createTree(subdatas , subLabels) return decisionTreedef decisionTree_test(): dataSet ,labels = createDataSet() decisionTree = createTree(dataSet , labels) print(decisionTree)decisionNode = dict(boxstyle='sawtooth' , fc='0.8')leafNode = dict(boxstyle="round4", fc='0.8')arrow_args = dict(arrowstyle="<-")def plotNode(nodeText , centerPoint , parentPoint , nodeType): # 注解(文本,起始结点,坐标,xy文本,文本坐标,垂直,水平居中,矩形样式,箭头样式) createPlot.ax1.annotate(nodeText , xy=parentPoint , xycoords="axes fraction", xytext=centerPoint , textcoords="axes fraction" , va="center" , ha="center" , bbox=nodeType , arrowprops=arrow_args)def createPlot(): fig = plt.figure(1 , facecolor='white' ) fig.clf() # subplot(行数,列数,编号) ,frameon是否显示网格 createPlot.ax1 = plt.subplot(111, frameon=False) plotNode("决策结点" , (0.5 , 0.1) , (0.1 , 0.5) , decisionNode) plotNode("叶节点" , (0.8,0.1) , (0.3 , 0.8) , leafNode ) plt.show()#获取叶节点个数,通过判断对应{key, val{}}中val中每个键对应的值如果是字典就递归累加;否则表明是孩子结点def getNumLeafs(myTree): leafNum = 0 #python3.x keys()返回字典 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key , value in secondDict.items() : if isinstance(value , dict): leafNum += getNumLeafs(value) else: leafNum += 1 return leafNum#获取树的层数,不断累加当前层数,选取层数中大者返回def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] depth = 0 for key , value in secondDict.items(): if isinstance(value ,dict): depth = 1 + getTreeDepth(value) #叶子结点高度为1 else: depth = 1 if depth > maxDepth: maxDepth = depth return maxDepthdef getTreeDepth_test(): dataSet ,labels = createDataSet() decisionTree = createTree(dataSet , labels) #print(decisionTree) leafNum = getNumLeafs(decisionTree) depth = getTreeDepth(decisionTree) print("leaf num: %d , depth: %d" % (leafNum , depth))#父子结点中间填充文本def plotMidText(centerPoint , parentPoint , textString): xMid = (parentPoint[0] - centerPoint[0]) / 2.0 + centerPoint[0] yMid = (parentPoint[1] - centerPoint[1]) / 2.0 + centerPoint[1] createPlot.ax1.text(xMid , yMid , textString)#计算宽和高# 绘制决策树def plotTree(myTree,parentPt,nodeTxt): # 定义并获得决策树的叶子结点数 numLeafs = getNumLeafs(myTree) #depth = getTreeDepth(myTree) # 得到第一个特征 firstStr = list(myTree.keys())[0] # 计算坐标,x坐标为当前树的叶子结点数目除以整个树的叶子结点数再除以2,y为起点 cntrPt = (plotTree.xOff + (1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) # 绘制中间结点,即决策树结点,也是当前树的根结点,这句话没感觉出有用来,注释掉照样建立决策树,理解浅陋了,理解错了这句话的意思,下面有说明 plotMidText(cntrPt, parentPt, nodeTxt) # 绘制决策树结点 plotNode(firstStr,cntrPt,parentPt,decisionNode) # 根据firstStr找到对应的值 secondDict = myTree[firstStr] # 因为进入了下一层,所以y的坐标要变 ,图像坐标是从左上角为原点 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 遍历secondDict for key in secondDict.keys(): # 如果secondDict[key]为一棵子决策树,即字典 if type(secondDict[key]).__name__ == 'dict': # 递归的绘制决策树 plotTree(secondDict[key],cntrPt,str(key)) # 若secondDict[key]为叶子结点 else: # 计算叶子结点的横坐标 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 绘制叶子结点 plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt, leafNode) # 这句注释掉也不影响决策树的绘制,自己理解的浅陋了,这行代码是特征的值 plotMidText((plotTree.xOff,plotTree.yOff), cntrPt, str(key)) # 计算纵坐标 plotTree.yOff = plotTree.yOff +1.0/plotTree.totalDdef createPlot(inTree): # 定义一块画布(画布是自己的理解) fig = plt.figure(1,facecolor='white') # 清空画布 fig.clf() # 定义横纵坐标轴,无内容 axprops = dict(xticks=[],yticks=[]) # 绘制图像,无边框,无坐标轴 createPlot.ax1 = plt.subplot(111,frameon=True,**axprops) # plotTree.totalW保存的是树的宽 plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalD保存的是树的高 plotTree.totalD = float(getTreeDepth(inTree)) # 决策树起始横坐标 plotTree.xOff = - 0.5 / plotTree.totalW #从0开始会偏右 #print(plotTree.xOff) # 决策树的起始纵坐标 plotTree.yOff = 1.0 # 绘制决策树 plotTree(inTree,(0.5,1.0),'') # 显示图像 plt.show()#使用决策树的分类函数:比较测试数据与决策树上的数值,递归执行过程直到进入叶子结点,将测试数据定义为叶子结点所属的类型def classify(inputTree , featLabels , testVec): firstStr = list(inputTree.keys())[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) for key , value in secondDict.items(): #找到属性对应的值 if testVec[featIndex] == key: if isinstance(value , dict): classLabel = classify(value , featLabels , testVec) else: classLabel = value return classLabel'''决策树的存储:构造决策树耗时,每次分类时调用已经构造好的决策树,使用pickle序列化对象。序列化对象:在磁盘保存 #必须以二进制形式保存 fw = open(fileName , "wb" ) # pickle.dump(obj , file, protocol) :件对象保存到文件中,pickle可以事先基本数据的序列和反序列化 pickle.dump(inputTree ,fw) fw.close() fr = open(fileName , "rb") # pickle.load(file):从文件中读取字符串,重构为原来的python对象 return pickle.load(fr)'''def storeTree(inputTree , fileName): import pickle #必须以二进制形式保存 fw = open(fileName , "wb" ) # pickle.dump(obj , file, protocol) :件对象保存到文件中,pickle可以事先基本数据的序列和反序列化 pickle.dump(inputTree ,fw) fw.close()def grabTree(fileName): import pickle fr = open(fileName , "rb") # pickle.load(file):从文件中读取字符串,重构为原来的python对象 return pickle.load(fr)#鱼分类问题def fishClassify(): #calcShannonEnt_test() #chooseBestFeature_test() #decisionTree_test() #createPlot() #getTreeDepth_test() dataSet ,labels = createDataSet() copyLabels = labels[:] #注意构建决策树会使得原来标签集发生改变,这里需要传入一个副本 decisionTree = createTree(dataSet , copyLabels) #序列化保存 fileName = "classifierStorage.txt" storeTree(decisionTree , fileName) decisionTree = grabTree(fileName) print(labels) result = classify(decisionTree , labels , [1, 0]) print(result) result = classify(decisionTree , labels , [1, 1]) print(result) createPlot(decisionTree)#镜片分类问题def lenseClassify(): fr = open('lenses.txt') lenses = [line.strip().split("\t") for line in fr.readlines()] lensesLabels = ["age" , "prescript" , "astigmatic" , "tearRate"] lenseTree = createTree(lenses , lensesLabels) createPlot(lenseTree)if __name__ == "__main__": fishClassify() lenseClassify()
0 0
- 机器学习实战---读书笔记: 第3章 决策树
- 读书笔记:机器学习实战【第3章 决策树】
- 机器学习实战第3章决策树
- 【机器学习实战】第3章 决策树
- 【读书笔记】机器学习实战-第三章 决策树
- 《机器学习实战》读书笔记 第三章 决策树(part 3)
- 机器学习实战读书笔记-决策树
- 代码注释:机器学习实战第3章 决策树
- 机器学习实战(第3章 决策树)
- 【机器学习实战】第3章 决策树(DecisionTree)
- 机器学习实战第3章-决策树(decision tree)
- 【读书笔记】机器学习实战 第7章 基于单层决策树的adaboost
- 《机器学习实战》读书笔记 第三章 决策树(part 1)
- 机器学习实战-第三章决策树-代码理解-读书笔记
- 机器学习实战第三章——决策树,读书笔记
- 【读书笔记】机器学习实战-决策树(1)
- 【读书笔记】机器学习实战-决策树(2)
- 《机器学习》读书笔记 6 第4章 决策树
- VC++ 通过ADO连接数据库查询时返回空值报错的解决方案
- 二叉树中和为某一值的路径
- Android自带的倒计时CountDownTimer
- HTML 提高页面加载速度的方法
- TP3.2的URL重写省略index.php问题
- 机器学习实战---读书笔记: 第3章 决策树
- [Weex-BBQ]Weex项目中引用css样式的三种姿势
- Oracle中查看所有表和字段以及表注释.字段注释
- s标签大全
- Android5.0之NavigationView的使用
- 模板的分离编译
- 【智库2861】大数据预测:天气和蛋挞的关系?是关联性还是因果关系
- 用chrome测试请求报文
- Python 输出相关内容