『机器学习实战』决策树

来源:互联网 发布:淘宝上买dnf账号 编辑:程序博客网 时间:2024/06/07 09:52

代码(保存为 trees_L.py,下面的运行代码通过 import trees_L 导入它):

#! /usr/bin/env python
# coding: utf-8
"""ID3 decision tree ("Machine Learning in Action", chapter 3).

A data set is a list of rows.  Every element of a row except the last is a
feature value; the last element is the row's class label.
"""
from math import log
import operator


def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in *dataSet*.

    The class label of a row is its last element.  An empty data set has
    entropy 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the toy "is it a fish?" data set and its feature names."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, with that column removed.

    New row lists are built, so *dataSet* itself is not modified.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Returns -1 when no feature gives a strictly positive gain, or when the
    rows have no feature columns.
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties are resolved in favour of the label whose dict entry comes first
    (sorted() is stable, including with reverse=True).
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty list of rows; the last element of each row is its
            class label.
        labels: names of the feature columns, in column order.  The list is
            not modified.

    Returns:
        A class label for a leaf, or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: this is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No feature columns are left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No feature has a positive gain (e.g. XOR-like data).  Index -1 would
        # address the class-label column, so split on the first feature.
        # Each split removes one column, so the recursion still terminates.
        bestFeat = 0
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build a new list of the remaining names instead of deleting from the
    # caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

运行代码:

import trees_L

# Driver for trees_L: build the toy data set, show its entropy, build a tree.
myDat, labels = trees_L.createDataSet()
# print(x) with a single argument prints the same on Python 2 and Python 3.
print(myDat)
print(labels)
print(trees_L.calcShannonEnt(myDat))
# Uncomment the next line to introduce a third class label ('maybe'); the
# entropy printed on the following line then increases.
#myDat[0][-1] = 'maybe'
print(trees_L.calcShannonEnt(myDat))
myTree = trees_L.createTree(myDat, labels)
print(myTree)


原创粉丝点击