Machine Learning in Action: Decision Trees
Source: Internet  Editor: 程序博客网  Date: 2024/06/05 19:50
Borrowing the current Baidu Baike definition: a decision tree is a decision-analysis method that, given the known probabilities of various outcomes, builds a tree to compute the probability that the expected net present value is greater than or equal to zero, thereby evaluating project risk and judging feasibility; it is a graphical method that applies probability analysis intuitively. Because the decision branches, when drawn, look like the branches of a tree, it is called a decision tree.
The process of building a decision tree can be expressed in pseudocode as follows:
Check whether every item in the dataset belongs to the same class:
    if so: return the class label
    else:
        find the best feature for splitting the dataset
        split the dataset
        create a branch node
        for each subset of the split:
            call this function recursively and add the result to the branch node
        return the branch node
When the data contains multiple features, how do we pick the most suitable feature to split on? For example, in the figure below:
Given "can survive without surfacing" and "has flippers", which feature should we split on first? Here we introduce a quantity called "information gain". We split the data in order to make disordered data more ordered, and one way to organize messy data is to measure information using information theory. In information theory, the change in information before and after splitting the data is called the information gain. So we compute the information gain for each candidate feature, and the feature with the highest information gain is the best choice.
The measure underlying information gain is called the Shannon entropy, or simply "entropy". It is computed as

    H = -Σᵢ p(xᵢ) · log₂ p(xᵢ)

where p(xᵢ) is the proportion of samples belonging to class xᵢ.
Each time we select a feature, we split the data on that feature and compute the entropy of the resulting subsets; we do this for every feature, then split on the feature that yields the smallest new entropy (equivalently, the largest information gain), recursing until only one sample remains or all samples share the same class label. Why pick the feature with the smallest new entropy? Because a smaller entropy means the data is less disordered, while a larger entropy means it is more disordered, and the whole point of splitting is to reduce the disorder of the data.
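As a quick sanity check on the formula, here is a minimal standalone sketch (the `entropy` helper is my own name, independent of the `trees.py` code below) that computes the entropy of a list of class labels. The sample dataset used later has 2 'yes' and 3 'no' labels:

```python
from collections import Counter
from math import log

def entropy(labels):
    """Shannon entropy of a list of class labels, in bits."""
    total = len(labels)
    return -sum((n / total) * log(n / total, 2)
                for n in Counter(labels).values())

# Labels of the 5-sample dataset used below: 2 'yes', 3 'no'
print(round(entropy(['yes', 'yes', 'no', 'no', 'no']), 4))  # 0.971
```

A pure set of labels (all the same class) has entropy 0, and a 50/50 split of two classes has entropy 1 bit, the maximum for two classes.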
(The code here uses Python 3.6; if you run into Chinese-encoding problems under Python 2.7, add #encoding=utf-8 at the top of the file):
trees.py
from math import log
import operator


def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


"""Compute the Shannon entropy"""
def calcShannonEnt(dataSet):
    length = len(dataSet)
    classCount = {}
    for i in range(length):
        if dataSet[i][-1] not in classCount.keys():
            classCount[dataSet[i][-1]] = 0
        classCount[dataSet[i][-1]] += 1
    shannonEnt = 0.0
    for classify in classCount:
        prob = float(classCount[classify]) / length
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


"""Split the dataset: keep the rows whose value on axis equals value, dropping that column"""
def splitDataSet(dataSet, axis, value):
    returnDataSet = []
    for data in dataSet:
        if data[axis] == value:
            reducedFeatVec = data[:axis]
            reducedFeatVec.extend(data[axis + 1:])
            returnDataSet.append(reducedFeatVec)
    return returnDataSet


"""Split on the feature whose split yields the smallest new entropy"""
def chooseBestFeatureToSplit(dataSet):
    featureNum = len(dataSet[0]) - 1
    length = len(dataSet)
    """
    The initial entropy is computed to make the later comparison convenient;
    different datasets have different initial entropies, so the comparison is
    otherwise hard to initialize.
    Thinking it over afterwards, there is another way to compare, as follows:

    baseEntropy = calcShannonEnt(dataSet)
    bestFeature = -1
    bestShannonEnt = baseEntropy                  # differs here
    for i in range(featureNum):
        featureList = [example[i] for example in dataSet]
        uniqueFeature = set(featureList)
        newShannonEnt = 0.0
        for feature in uniqueFeature:
            returnDataSet = splitDataSet(dataSet, i, feature)
            prob = len(returnDataSet) / length
            newShannonEnt += prob * calcShannonEnt(returnDataSet)
        if newShannonEnt < bestShannonEnt:        # differs here
            bestShannonEnt = newShannonEnt        # differs here
            bestFeature = i
    return bestFeature
    """
    baseEntropy = calcShannonEnt(dataSet)
    bestFeature = -1
    bestShannonEnt = 0.0
    for i in range(featureNum):
        featureList = [example[i] for example in dataSet]
        uniqueFeature = set(featureList)
        newShannonEnt = 0.0
        for feature in uniqueFeature:
            returnDataSet = splitDataSet(dataSet, i, feature)
            prob = len(returnDataSet) / length
            newShannonEnt += prob * calcShannonEnt(returnDataSet)
        if baseEntropy - newShannonEnt > bestShannonEnt:
            bestShannonEnt = baseEntropy - newShannonEnt
            bestFeature = i
    return bestFeature


"""Get the most common class label"""
def majorityCnt(dataSet):
    classCount = {}
    for data in dataSet:
        vote = data[-1]
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


"""Recursively build the tree"""
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(dataSet)
    bestFeatrue = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeatrue]
    del (labels[bestFeatrue])
    myTree = {bestLabel: {}}
    featureValues = [example[bestFeatrue] for example in dataSet]
    uniqueValues = set(featureValues)
    for feature in uniqueValues:
        subLabels = labels[:]
        myTree[bestLabel][feature] = createTree(
            splitDataSet(dataSet, bestFeatrue, feature), subLabels)
    return myTree


"""Classify with the decision tree"""
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if key == testVec[featIndex]:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


"""Store the tree on disk; Python 3's pickle writes binary, so use 'wb'/'rb'"""
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


"""Load the tree from disk; Python 3's pickle writes binary, so use 'wb'/'rb'"""
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
treePlotter.py
import matplotlib.pyplot as plt

dicisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


"""Draw the arrow, the text, and the box"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


"""This function is only a demo (it is redefined further down)"""
def createPlot():
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), dicisionNode)
    plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


"""Recursively count the tree's leaf nodes"""
def getLeafsNum(myTree):
    leafsNum = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            leafsNum += getLeafsNum(secondDict[key])
        else:
            leafsNum += 1
    return leafsNum


"""Recursively get the tree's depth"""
def getTreeDeapth(myTree):
    maxDeapth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDeapth = 1 + getTreeDeapth(secondDict[key])
        else:
            thisDeapth = 1
        if thisDeapth > maxDeapth:
            maxDeapth = thisDeapth
    return maxDeapth


def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]


"""Print the feature value (here 0|1) between two nodes"""
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeText):
    leafsNum = getLeafsNum(myTree)
    depth = getTreeDeapth(myTree)
    """
    Get the first string key (index 0); for
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    the first call yields 'no surfacing'
    """
    firstStr = list(myTree.keys())[0]
    """Coordinates of the current node"""
    cntrPt = (plotTree.xOff + (1.0 + float(leafsNum)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    """Draw 0|1 between the two nodes"""
    plotMidText(cntrPt, parentPt, nodeText)
    plotNode(firstStr, cntrPt, parentPt, dicisionNode)
    """Children of the current node"""
    secondDict = myTree[firstStr]
    """We are about to draw children, so decrease y to move down"""
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        """If the child is not a leaf, recursively draw the subtree"""
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    """Done drawing the children, so increase y to move back up"""
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


"""Draw the figure"""
def createPlot(inTree):
    """figure's first argument is the window name 'figure 1'; facecolor is the background color"""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    """axprops holds the tick positions to show on the axes; try axprops = dict(xticks=[0, 0.5, 1], yticks=[0, 1])"""
    axprops = dict(xticks=[], yticks=[])
    """
    plt.subplot's first argument is really three digits nmp: split the canvas
    into an n*m grid, and this plot is number p.
    frameon says whether a frame surrounds the axes; the last argument carries
    the axis settings.
    """
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    """Total width"""
    plotTree.totalW = float(getLeafsNum(inTree))
    """Total depth"""
    plotTree.totalD = float(getTreeDeapth(inTree))
    """
    x coordinate of the leaf drawn before the one currently being drawn;
    initialized to one slot before the first leaf's x coordinate, i.e. the
    first leaf's x minus 1/plotTree.totalW. It is easier to understand by
    walking through the whole drawing process; it is hard to explain briefly.
    """
    plotTree.xOff = -0.5 / plotTree.totalW
    """Initial y coordinate"""
    plotTree.yOff = 1.0
    """Draw the root at coordinates (0.5, 1.0)"""
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
Below is some of my test code, test.py:
import trees
import treePlotter

# dataSet, labels = trees.createDataSet()
# shannonEnt = trees.calcShannonEnt(dataSet)
# print(trees.createTree(dataSet, labels))
#
# treePlotter.createPlot()
# myTree = treePlotter.retrieveTree(0)
# print(trees.classify(myTree, labels, [1, 0]))
# print(trees.classify(myTree, labels, [1, 1]))
# treePlotter.createPlot(myTree)
# trees.storeTree(myTree, 'classifierStorage.txt')
# print(trees.grabTree('classifierStorage.txt'))
#
# fr = open('lenses.txt')
# lenses = [inst.strip().split('\t') for inst in fr.readlines()]
# lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
# lensesTree = trees.createTree(lenses, lensesLabels)
# print(lensesTree)
# treePlotter.createPlot(lensesTree)

myDat, labels = trees.createDataSet()
print(trees.chooseBestFeatureToSplit(myDat))
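The last two lines should print 0, i.e. 'no surfacing' is the best first split. The information-gain arithmetic can be checked standalone (the `entropy` and `split` helpers below are illustrative names of my own, not part of trees.py):

```python
from collections import Counter
from math import log

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

def entropy(rows):
    # Shannon entropy of the class labels (last column), in bits
    total = len(rows)
    counts = Counter(row[-1] for row in rows)
    return -sum((n / total) * log(n / total, 2) for n in counts.values())

def split(rows, i, v):
    # Rows whose feature i equals v (column kept, since we only need labels)
    return [r for r in rows if r[i] == v]

base = entropy(dataSet)  # ~0.971
for i in range(2):
    # Weighted entropy of the subsets after splitting on feature i
    new = sum(len(split(dataSet, i, v)) / len(dataSet) * entropy(split(dataSet, i, v))
              for v in {row[i] for row in dataSet})
    print(i, round(base - new, 4))
# 0 0.42
# 1 0.171
```

Feature 0 gains about 0.42 bits versus about 0.17 bits for feature 1, so chooseBestFeatureToSplit returns 0.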