决策树算法实现

来源:互联网 发布:数控螺纹编程实例 编辑:程序博客网 时间:2024/05/22 12:45
# -*- coding: utf-8 -*-
from math import log
import operator
#计算熵
def clacShannonEnt (dataset):
    numEntries= len(dataset)
    labelCounts={}
    for featVec in dataset:
        currentlabel=featVec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel]=0
        labelCounts[currentlabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -=prob * log(prob,2)
    return shannonEnt
#构造数据
def createDataSet():
    dataset=[[1,1,'y'],[1,1,'y'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataset,labels

myDat,myLabel = createDataSet()

# shan=clacShannonEnt(myDat)
# print shan
# myDat[0][-1]='maybe'
# shan=clacShannonEnt(myDat)
# print shan

#按照给定特征划分数据集,划分数据集是指从特征集迭代出每一个位置的每一种可能的特征值,
#然后在这个位置按照所有可能的情况分开。
def splitDataSet(dataset,axis,value):
    #三个参数为:待划分的数据集,划分数据集的特征,特征的返回值
    #我的理解是:待划分的数据集,选定的数据集中特征的位置,特征位置上的特征值
    retDataSet=[]
    for featVec in dataset:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

splitData=splitDataSet(myDat, 0, 0)
# print splitData

#选择最好的数据集划分方式,根据信息增益
def chooseBestFeatureToSplit(dataset):
    numFeatures=len(dataset[0])-1
    #未划分数据集时候的信息熵
    baseEntropy=clacShannonEnt(dataset)
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        #列表推导,从列表推导出新的列表
        featList = [example[i] for example in dataset]
        #set集合可以把所有重复元素剔除掉
        uniqueVals=set(featList)
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataset, i, value)
            prob=len(subDataSet)/float(len(dataset))
            #计算每一种划分方式的信息熵,划分方式指的是选定哪一个特征位置划分,即i的位置
            newEntropy +=prob*clacShannonEnt(subDataSet)
        #信息增益
        infoGain=baseEntropy-newEntropy
        if(infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature
        
# bestFeatureMyDat=chooseBestFeatureToSplit(myDat)
# print bestFeatureMyDat

#建立多数表决函数
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]    

#建立树
def createTree(dataset,labels):
    classList =[example[-1] for example in dataset]
    #count() 方法用于统计某个元素在列表中出现的次数。
    #如果第一行的类别个数总和与列表行数总和相等,说明列表中所有的类别都与第一个相同,就是类别完全相同,停止划分
    if classList.count(classList[0])==len(classList):
        return classList[0]
    #遍历完所有特征时返回出现次数最多
    if len(dataset[0])==1:
        return majorityCnt(classList)

    bestFeat=chooseBestFeatureToSplit(dataset)
    #数据集并不含label,bestFeat是label的位置
    bestFeatLabel=labels[bestFeat]
    #选好了当前最好的划分点后构造以最佳分割label即bestFeatLabel的树
    myTree={bestFeatLabel:{}}
    #然后删除这个分类节点,因为已经使用过
    del(labels[bestFeat])
    #得到列表中所有的属性值
    featValues=[example[bestFeat] for example in dataset]
    uniqueValue=set(featValues)
    for value in uniqueValue:
        subLabels=labels[:]
        #最后按照最佳分割点位置上的值进行分类,有几个值就又建几棵树。递归
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataset, bestFeat, value), subLabels)
    return myTree
 
myTree=createTree(myDat, myLabel)       
# print myTree

def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            #如果测试向量的特征值仍然是字典
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key], featLabels, testVec)
            else:
                classLabel=secondDict[key]
    return classLabel
myDat,myLabel = createDataSet()
classLabel=classify(myTree, myLabel, [0,1])
# print classLabel

#运用pickle序列化对象
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'w')
    pickle.dumps(inputTree, fw)
    fw.close()
    
def grabTree(filename):
    import pickle
    fr=open(filename)
    return pickle.load(fr)

# myTree=createTree(myDat, myLabel)      
# storeTree(myTree,'D:/learn/Ch02/classifierStorage.txt')
# print grabTree('D:/learn/Ch02/classifierStorage.txt')

def glass(filename):
    fr=open(filename)
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabel=['age','prescript','astigmatic','tearRate']
    lensesTree=createTree(lenses, lensesLabel)
    return lensesTree

lensesTree=glass('D:/learn/Ch02/lenses.txt')
print lensesTree
0 0
原创粉丝点击