cart未实现

来源:互联网 发布:网络洛克人 编辑:程序博客网 时间:2024/04/29 16:45
#coding=utf-8
import csv
import numpy as np
from numpy import *
import pandas as pd
from math import log
import operator
import matplotlib.pyplot as plt


file=open("1.csv",'r')
reader=csv.reader(file)
rows=[row for row in reader]
data=rows[1:11]
data_full=[datas[1:] for datas in data]
#data_full=data[0]
#data_full=data_full[1:]
lables=rows[0][1:-1]
lables_full=lables[:]
#print(data_full,lables)
dataSet=data_full


def calcGini(dataSet):  #计算数据集的基尼指数。
    global numEntries
    numEntries=len(dataSet)
    lableCounts={} #选择一个属性,记录每一个标记值的个数
    for featVec in dataSet:         #迭代取出每个对象
        currentLable=featVec[-1]    #对象的标记
        if currentLable not in lableCounts.keys():
            lableCounts[currentLable]=0   #记录每个标记的个数。
        else:
            lableCounts[currentLable]+=1
    Gini=1.0




    for key in lableCounts:     #迭代取出字典中的键。
        prob=float(lableCounts[key])/numEntries  #计算prob,即第key中标记的个数/总数
        Gini-=prob*prob
    return Gini


"@myself"
def Gini_index(dataSet,a):  #给定属性的gini指数
    
    indexa=lables.index(a)    #属性的角标
    aValue=[dataA[indexa] for dataA in dataSet]  #属性对应的属性值。
    
    attributeValueCount={}    #属性值个数
    attributeValueObject={}   #存放属性值对应的数据


    count=0
    for attributeValue in aValue:
        #indexAttributeValue=aValue.index(attributeValue)
        content=dataSet[count]   #取出attributeValue对应的数据,因为attributeValue是遍历的,所以逐个取出数据中的 数据即可
        
        
        if attributeValue not in attributeValueCount.keys():
            attributeValueObject[attributeValue]=[]   #如果这个属性值没存在过,为存储这个属性值对应的数据开辟一个list。
            attributeValueCount[attributeValue]=0   #记录这个属性值attributeValue有多少个
            
            attributeValueObject[attributeValue].extend([count])
            attributeValueObject[attributeValue][attributeValueCount[attributeValue]]=content #存放属性值attributeValue对应的数据
        else:
            attributeValueCount[attributeValue]+=1   
            
            attributeValueObject[attributeValue].extend([count])
            attributeValueObject[attributeValue][attributeValueCount[attributeValue]]=content #存放属性值attributeValue对应的数据


        count+=1
            
    Ginia=0.0
    for key in attributeValueCount:
        subDataSet=attributeValueObject[key] #根据属性得到数据集
        subGini=calcGini(subDataSet)
        
        Ginia+=(attributeValueCount[key]/numEntries)*subGini


    return Ginia
"@myself"           
def findMinGiniAttribute():  #找到最优基尼值属性,用于划分。
    listGini=[]
    for i in lables:
        listGini.append(Gini_index(data_full,i))
    minindex=listGini.index(min(listGini))
    minGiniAttribute=lables[minindex]
    return minGiniAttribute




def splitDataSet(dataSet,axis,value):  #axis是一个数,属性的列数,
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:        #判断属性值是否为value,
            reducedFeatVec=featVec[:axis]  #如果相等,
            reducedFeatVec.extend(featVec[axis+1:])#将属性所在列除去,其他的添加到retDataSet,因为这个离散属性用过之后不能再用了。
                                                   #通过这个属性分过之后,形成的子数据集,再通过其他属性进行划分
            retDataSet.append(reducedFeatVec)
    return retDataSet


#对连续变量划分数据集,direction规定划分的方向,
#决定是划分出小于value的数据样本还是大于value的数据样本数。
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet




#选择最好的数据集划分方式  
def chooseBestFeatureToSplit(dataSet,lables):  
    numFeatures=len(dataSet[0])-1  #属性集的最大角标
    bestGiniIndex=100000.0    #最大gini指数
    bestFeature=-1            #最好的属性
    bestSplitDict={}      #字典  
    for i in range(numFeatures):  #迭代取出每个属性,可能有多个连续属性,和多个离散属性,为了得到最好的划分属性
        featList=[example[i] for example in dataSet]  #迭代取出每个属性对应的值
        #对连续型特征进行处理  
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':  #如果属性是连续的
            #产生n-1个候选划分点  
            sortfeatList=sorted(featList) #对属性值进行排序 
            splitList=[]  
            for j in range(len(sortfeatList)-1):  
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0) #将每两个相邻属性值的平均数作为候选划分点 
              
            bestSplitGini=10000    #最好分离gini指数,初始化只要够大,是随便的,相当于定义了一个变量。
            slen=len(splitList)    #求出候选划分点的长度
            #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点  
            for j in range(slen):  
                value=splitList[j] #迭代取出候选划分点 
                newGiniIndex=0.0  
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)  #condition=0,选择大于value的子集
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)  
                prob0=len(subDataSet0)/float(len(dataSet))  
                newGiniIndex+=prob0*calcGini(subDataSet0)    #计算出这个划分点的gini指数 
                prob1=len(subDataSet1)/float(len(dataSet))  
                newGiniIndex+=prob1*calcGini(subDataSet1)  
                if newGiniIndex<bestSplitGini:   #如果新划分点的gini指数小于以前的最好划分点 
                    bestSplitGini=newGiniIndex    #将新gini指数给最好划分gini指数
                    bestSplit=j                   #记录下最好gini指数的下标。
            #用字典记录当前特征的最佳划分点  
            bestSplitDict[lables[i]]=splitList[bestSplit]  #lables[i]第i+1个属性。字典,键值对记录。以键取值 
              
            GiniIndex=bestSplitGini  #得到gini指数,


            
            
        #对离散型特征进行处理  
        else:  
            uniqueVals=set(featList)  #将离散属性值转化为集合  
            newGiniIndex=0.0          #定义一个新gini指数
            #计算该特征下每种划分的信息熵  
            for value in uniqueVals:  #迭代取出每个属性值
                subDataSet=splitDataSet(dataSet,i,value)  #值为value的数据集对象。去除i列的元素。因为这个属性用过之后,就不能再用了。
                                                          #连续值可以再用。
                prob=len(subDataSet)/float(len(dataSet))  
                newGiniIndex+=prob*calcGini(subDataSet)    #计算gini指数
            GiniIndex=newGiniIndex  
        if GiniIndex<bestGiniIndex:  
            bestGiniIndex=GiniIndex  #最好的gini指数  
            bestFeature=i    #记录最好的属性位置
            
    #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理  
    #即是否小于等于bestSplitValue  
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':        
        bestSplitValue=bestSplitDict[lables[bestFeature]]    #最好的划分值     
        lables[bestFeature]=lables[bestFeature]+'<='+str(bestSplitValue)  #最好的属性
        for i in range(shape(dataSet)[0]):  
            if dataSet[i][bestFeature]<=bestSplitValue:  
                dataSet[i][bestFeature]=1  #小于划分点,将属性值置1
            else:  
                dataSet[i][bestFeature]=0  #否则置0
    return bestFeature  #返回最好的划分属性位置。




#特征已经划分完,节点下的样本还没有统一值。则需要进行投票
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    value=max(classCount.values())#返回个数最多的类
    for k,v in classCount.items():
        if v==value:
            return k            #返回个数最多的类




#主程序,递归产生决策树
def createTree(dataSet,lables,data_full,lables_full):    #输入数据
    classList=[example[-1] for example in dataSet]    #类列表。
    if classList.count(classList[0])==len(classList):   #classList.count(classList[0])是lassList[0]在lassList中发生的次数。
        return classList[0]     #返回标记。
    if len(dataSet[0])==0:
        return majorityCnt(classList)  #返回类标记
    bestFeat=chooseBestFeatureToSplit(dataSet,lables)   #得到最好的划分属性位置。
    bestFeatlable=lables[bestFeat]   #得到最好的划分属性
    myTree={bestFeatlable:{}}#最好的划分属性作为键,一个空字典作为值
    featValues=[example[bestFeat] for example in dataSet] #最好划分属性对应的属性值
    uniqueVals=set(featValues)  #属性值转化为集合。
    if type(dataSet[0][bestFeat]).__name__=='str':  #如果属性为字符串
        currentLable=lables_full.index(lables[bestFeat]) #找出属性的角标
        featValuesFull=[example[currentLable] for example in data_full]#最好划分属性对应的属性值
        uniqueValsFull=set(featValuesFull)#转化为集合
    del(lables[bestFeat])#删除角标及元素


    for value in uniqueVals:  #迭代取出每个属性值
        sublables=lables[:]   #子属性
        if type(dataSet[0][bestFeat]).__name__=='str':#如果属性为字符串
            uniqueValsFull.remove(value)#删除属性值
        myTree[bestFeatlable][value]=createTree(splitDataSet(dataSet,bestFeat,value),sublables,data_full,lables_full)
         #创建树,以最好属性为横坐标,value为纵坐标,递归。
    if type(dataSet[0][bestFeat]).__name__=='str':  #如果属性为字符串
        for value in uniqueValsFull:                 #迭代取出每个属性值
            myTree[bestFeatlable][value]=majorityCnt(classList)   #将类赋值给树。 
    return myTree


plt.plot(createTree(data,lables,data,lables_full))
plt.show()


    
        










































        
0 0