机器学习实战：决策树（Decision Trees）

来源:互联网 发布:apache spark mahout 编辑:程序博客网 时间:2024/05/29 17:10


from numpy import *
from math import log
import operator
import pickle


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in dataSet.

    Each row of dataSet is a sequence whose LAST element is the class label;
    the preceding elements are feature values. An empty dataSet yields 0.0.
    """
    num = len(dataSet)
    labelCount = {}
    for data in dataSet:
        currentLabel = data[-1]
        labelCount[currentLabel] = labelCount.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCount.values():
        p = float(count) / num
        shannonEnt -= p * log(p, 2)
    return shannonEnt


def createDataSet():
    """Return the toy "is it a fish?" data set and its feature names.

    Returns (dataSet, labels): rows are [no surfacing, flippers, class].
    """
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, with that column removed.

    Each returned row is a new list; dataSet itself is not modified. Rows may
    be lists or tuples.
    """
    retDataSet = []
    for featureVec in dataSet:
        if featureVec[axis] == value:
            # list() so tuple rows work too; slicing already copies lists.
            reduced = list(featureVec[:axis])
            reduced.extend(featureVec[axis + 1:])
            retDataSet.append(reduced)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Information gain = entropy of dataSet minus the size-weighted entropy of
    the subsets produced by splitting on the feature. Returns -1 when the rows
    contain no feature columns (label only). dataSet must be non-empty.
    """
    baseEntropy = calcShannonEnt(dataSet)
    numFeatures = len(dataSet[0]) - 1
    total = float(len(dataSet))
    bestFeature = -1
    bestGain = -1  # sentinel: any real gain (>= 0) beats it
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        condEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            condEntropy += len(subDataSet) / total * calcShannonEnt(subDataSet)
        gain = baseEntropy - condEntropy
        if gain > bestGain:
            bestGain = gain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in classList.

    On ties the label counted first wins (sorted() is stable and, on
    Python 3.7+, dict iteration follows insertion order).

    Raises:
        ValueError: if classList is empty.
    """
    if len(classList) == 0:
        raise ValueError("majorityCnt() requires a non-empty classList")
    classCount = {}
    for item in classList:
        # Default 0 so the first occurrence of a label counts as 1.
        classCount[item] = classCount.get(item, 0) + 1
    sortedClass = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClass[0][0]


def createTree(dataSet, labels):
    """Build an information-gain (ID3-style) decision tree as nested dicts.

    Args:
        dataSet: non-empty sequence of rows; each row holds feature values
            followed by the class label.
        labels: feature names; labels[i] names column i. NOT modified.

    Returns:
        A class label (leaf), or {featureName: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    if len(set(classList)) == 1:
        return classList[0]
    if len(dataSet[0]) == 1:
        # No features left to split on but classes still disagree.
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestLabel = labels[bestFeature]
    # Build a reduced copy instead of deleting from the caller's list.
    subLabels = labels[:bestFeature] + labels[bestFeature + 1:]
    trees = {bestLabel: {}}
    for value in set(example[bestFeature] for example in dataSet):
        subDataSet = splitDataSet(dataSet, bestFeature, value)
        trees[bestLabel][value] = createTree(subDataSet, subLabels)
    return trees


def classify(inputTree, featLabels, testVec):
    """Return the class label that inputTree assigns to testVec.

    Args:
        inputTree: a tree from createTree (nested dict, or a bare leaf label).
        featLabels: feature names in the same column order as testVec.
        testVec: feature values to classify.

    Raises:
        ValueError: a feature named in the tree is missing from featLabels.
        KeyError: testVec has a value for which the tree has no branch.
    """
    node = inputTree
    while isinstance(node, dict):
        # createTree gives every internal node exactly one key: the feature.
        featLabel = next(iter(node))
        branches = node[featLabel]
        try:
            featIndex = featLabels.index(featLabel)
        except ValueError:
            raise ValueError("List does not contain value: %r" % (featLabel,))
        value = testVec[featIndex]
        try:
            node = branches[value]
        except KeyError:
            raise KeyError("no branch for %r = %r in the decision tree" % (featLabel, value))
    return node


def storeTree(inputTree, filename):
    """Serialize inputTree to filename with pickle (binary mode)."""
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously written by storeTree.

    Security: unpickling can execute arbitrary code; only load files you
    created yourself or otherwise trust.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


if __name__ == "__main__":
    dataSet, label = createDataSet()
    trees = createTree(dataSet, label)
    print(trees)
    r = classify(trees, label, [0, 1])
    print(r)

    # Contact-lens example (needs a tab-separated lenses.txt in the current
    # working directory):
    # with open('lenses.txt') as fr:
    #     lenses = [inst.strip().split('\t') for inst in fr]
    # lenseLabel = ['age', 'prescript', 'astigmatic', 'tearRate']
    # trees = createTree(lenses, lenseLabel)
    # result = classify(trees, lenseLabel, ['pre', 'myope', 'no', 'normal'])
    # print(result)


0 0
原创粉丝点击