python 使用Id3算法实现决策树

来源:互联网 发布:钻石净度分级标准知乎 编辑:程序博客网 时间:2024/06/05 06:44

依然是学习《统计学习方法》一书所做的简单实验,写代码的过程参考了大量其他的博客,本人在此深表感谢。代码实现的依然是书上的例子:
这里写图片描述

import numpy as npimport mathimport operatordef CreateDataSet():    dataset = [ [1, 0,0,0,'no'],                [1, 0,0,1,'no'],                [1, 1,0,1,'yes'],                [1, 1,1,0,'yes'],                [1, 0,0,0,'no'],                [2, 0,0,0,'no'],                [2, 0,0,1,'no'],                [2, 1,1,1,'yes'],                [2, 0,1,2,'yes'],                [2, 0,1,2,'yes'],                [3, 0,1,2,'yes'],                [3, 0,1,1,'yes'],                [3, 1,0,1,'yes'],                [3, 1,0,2,'yes'],                [3, 0,0,0,'no'] ]    labels = ['age', 'job','building','credit']    return dataset, labels#计算香农熵def calcShannonEnt(dataSet):    Ent = 0.0    numEntries = len(dataSet)    labelCounts = {}    for feaVec in dataSet:        currentLabel = feaVec[-1]        if currentLabel not in labelCounts:            labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    for key in labelCounts:        prob = float(labelCounts[key])/numEntries        Ent -= prob * math.log(prob, 2)    return Entdef majorityCnt(classList):    classCount = {}    for vote in classList:        if vote not in classCount.keys():            classCount[vote] = 0        classCount[vote] = 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def splitDataSet(dataSet,axis,value):    retDataSet=[]    for featVec in dataSet:        if featVec[axis]==value :            reduceFeatVec=featVec[:axis]            reduceFeatVec.extend(featVec[axis+1:])            retDataSet.append(reduceFeatVec)    return retDataSet #返回不含划分特征的子集def choiceBestFea(dataSet):    baseEntropy = calcShannonEnt(dataSet)    numberFeatures = len(dataSet[0]) - 1    bestFeatureId = -1;    bestInfoGain = 0.0    for i in range(numberFeatures):        featList = [example[i] for example in dataSet]        uniqueVals = set(featList)        newEntropy = 0.0        for value in uniqueVals:            subFea = splitDataSet(dataSet,i,value)            prob = len(subFea) / float(len(dataSet))            newEntropy += prob * calcShannonEnt(subFea)        infoGain = baseEntropy - newEntropy        if (infoGain > bestInfoGain):            bestInfoGain = infoGain            bestFeatureId = i    return bestFeatureIddef createDTree(dataSet,labels):    #第一步,判断数据是不是都是同一类的,如果是同一类的,则只有一个节点即根节点    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len(classList):        return classList[0]    # 第二步,判断特征的个数,特征集为空,则只有一个节点即根节点,此时,需要通过投票的方式决定根节点的属性    if len(dataSet[0]) == 1:        return majorityCnt(classList)    # 第三步,通过计算信息增益,选择出最优的特征,也就是信息增益最大的特征    bestFeaId = choiceBestFea(dataSet)    #第四步,选择出信息增益最大的特征,并使用该特征切分数据    bestFeatLabel = labels[bestFeaId]    del (labels[bestFeaId])    featValues = [example[bestFeaId] for example in dataSet]    uniqueVals = set(featValues)    myTree = {bestFeatLabel: {}}    #第五步,递归调用createDTree    for value in uniqueVals:        subLabels = labels[:]        myTree[bestFeatLabel][value] = createDTree(splitDataSet(dataSet, bestFeaId, value), subLabels)    return myTree#输入两个变量(决策树,测试的数据)def classify(inputTree,testVec):    print(inputTree)    firstStr=list(inputTree.keys())[0] #获取树的第一个特征属性    secondDict=inputTree[firstStr] #树的分支,子集合Dict    i=0    classLabel = ""    for key in secondDict.keys():        if testVec[i]==key:            if type(secondDict[key]).__name__=='dict':                classLabel=classify(secondDict[key],testVec)            else:                #表明已经是叶子节点了                classLabel=secondDict[key]                break            i += 1    return classLabeldef storeTree(inputTree,filename):    import pickle    fw=open(filename,'wb') #pickle默认方式是二进制,需要制定'wb'    pickle.dump(inputTree,fw)    fw.close()def reStoreTree(filename):    import pickle    fr=open(filename,'rb')#需要制定'rb',以byte形式读取    return pickle.load(fr)def test():    dataSet,labels = CreateDataSet1()    tree = createDTree(dataSet,labels);    print(tree)    return Nonedef train():    myDat, labels = CreateDataSet()    tree = createDTree(myDat, labels)    storeTree(tree,"dtree.txt")    return Nonedef test():    tree = reStoreTree("dtree.txt")    result = classify(tree,[0,0])    return resultresult = test()print(result)#train()

train()方法用来生成决策树,生成的决策树会被保存在dtree.txt文件中
test()方法用来测试决策树。
从生成的决策树来看,总共只有两个节点。第一个节点是有没有房,第二个节点是有没有工作。所以,测试的时候只需输入【0,0】或者【1,0】这样的长度为2的向量即可。

原创粉丝点击