Machine Learning: Implementing a Decision Tree (Python)


The general workflow for building a decision tree:

(1) Collect the data

(2) Prepare the data

(3) Analyze the data

(4) Train the algorithm

(5) Test the algorithm

(6) Use the algorithm

Since I am still just getting started, everything here is reproduced code, taken from Machine Learning in Action.

There are several decision tree algorithms (ID3, CART, C4.5/C5.0, and so on); the one implemented here is ID3, which picks the split that maximizes information gain.


I will assume everyone is already familiar with information entropy and information gain.
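For reference, these are the two quantities the code below computes (standard definitions, stated here for completeness). The Shannon entropy of a dataset D whose class labels occur with probabilities p_i is

    H(D) = -\sum_i p_i \log_2 p_i

and the information gain from splitting on feature A is

    \mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|} H(D_v)

where D_v is the subset of D in which A takes value v. For the toy dataset below (2 'yes' labels and 3 'no' labels), H(D) = -(2/5)\log_2(2/5) - (3/5)\log_2(3/5) \approx 0.971 bits.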

from math import log
import operator

# Compute the Shannon entropy of the class labels in dataSet
def calShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # accumulate -p * log2(p)
    return shannonEnt

# Create the toy dataset: two binary features plus a class label
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

# Split the dataset on a given feature: keep the rows where feature
# `axis` equals `value`, with that feature column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Choose the best feature to split on (largest information gain)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# Return the class name that occurs most often
# (used when all features have been consumed)
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Recursively build the tree as nested dicts
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # stop if all remaining samples share the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features used up: fall back to the most frequent class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    # branch on every observed value of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so recursion does not clobber labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)
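Running this prints the learned tree as nested dicts; with the toy dataset above it should output:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Machine Learning in Action also walks the finished tree to label new samples. A minimal sketch along those lines (the classify helper below is my addition, not part of the listing above):

def classify(inputTree, featLabels, testVec):
    # the tree is a dict of dicts; a leaf is a plain class label string
    firstStr = list(inputTree.keys())[0]          # feature name at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)        # map feature name -> column
    valueOfFeat = secondDict[testVec[featIndex]]  # follow the matching branch
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat

myDat, featLabels = createDataSet()  # fresh labels: createTree mutated the old list
print(classify(myTree, featLabels, [1, 0]))  # expected: 'no'
print(classify(myTree, featLabels, [1, 1]))  # expected: 'yes'

Note that createTree deletes entries from the labels list it is given, which is why a fresh copy of the labels is fetched before classifying.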

