连续数值属性的CART decision tree

来源:互联网 发布:淘宝下架宝贝找不到 编辑:程序博客网 时间:2024/05/24 04:18

划分数据集:Iris;

数据形如:      

 a    b    c    d           class
0    5.1  3.5  1.4  0.2     Iris-setosa
1    4.9  3.0  1.4  0.2     Iris-setosa
2    4.7  3.2  1.3  0.2     Iris-setosa
3    4.6  3.1  1.5  0.2     Iris-setosa
4    5.0  3.6  1.4  0.2     Iris-setosa,

一共四维属性外加class属性。

划分选择依据:基尼指数

连续数值的处理:对Iris中每一维的连续两个数值元素求平均值,构成(n-1)*4维的划分点集合;对每一维中的划分点集合迭代计算基尼指数,将最小值作为最优划分属性。

需要一提的是:做连续数值划分的决策树特别容易出现过拟合的情况,因为终止条件苛刻,在依据某一维属性比如a<7.2进行二维划分之后,下次选择最优划分属性时a并不像离散决策树中可以排除,因为a还可以变成a<3.5,所以决策树的终止条件要依靠剪枝策略来完善。但是在做出划分结果的时候我已经很开心了,迫不及待的贴出来,剪枝的事情,下次再做吧。

贴代码:

# -*- coding: utf-8 -*-"""Created on Wed Sep 20 11:16:37 2017@author: wjw"""import numpy as npimport pandas as pddef readText(filePath):        lines = open(filePath,'r').readlines()    data = []        for line in lines:        dataList = line.split(',')        data.append([float(dataList[0]),float(dataList[1]),float(dataList[2]),                     float(dataList[3]),dataList[4].split("\n")[0]])            data = pd.DataFrame(data,columns=["a","b","c","d","class"])    return data"""       a    b    c    d           class0    5.1  3.5  1.4  0.2     Iris-setosa1    4.9  3.0  1.4  0.2     Iris-setosa2    4.7  3.2  1.3  0.2     Iris-setosa3    4.6  3.1  1.5  0.2     Iris-setosa4    5.0  3.6  1.4  0.2     Iris-setosa"""def binSplitData(data,feature,value):#将数据二分开        data0 = data[data[feature]<=value]    data1 = data[data[feature]>value]    binData = [data0,data1] #binData是一个三维list        return binDatadef chooseBestFeatureToSplit(data):    avg_set = process(data)    gini = calGiniIndex(data,avg_set)    min_avg,minColumn = getMINGini(gini,avg_set)        return min_avg,minColumndef tree(data):            countList = data.groupby('class').count().iloc[:,0] #得到data数据的class统计量    if countList[0]==data.shape[0]: #如果样本中的所有元素属于同一类别,把这些数据从要分类的中删除        print("所属类别是:%s"%(data.iloc[0,-1]))        return #data.iloc[0,-1] #返回类别        min_avg,minColumn = chooseBestFeatureToSplit(data)        print('现在判断属性%s'%(minColumn))        binData = binSplitData(data,minColumn,min_avg)    for i in range(2):                if i==0:            print("if属性%s<=%s"%(minColumn,min_avg)) #下一步划分的前提条件        elif i==1:            print("if属性%s>%s"%(minColumn,min_avg))                tree(binData[i])    return def getMINGini(gini,avg_set):    minV = 10.        for column_index in range(gini.shape[1]):        gini_column = gini.iloc[:,column_index]        newmin = min(gini_column)        if newmin < minV:            minV = newmin             minColumn = column_index        min_avg = avg_set[minColumn][gini.iloc[:,minColumn][gini.iloc[:,minColumn]==minV].index.tolist()]        return min_avg[0],gini.columns[minColumn]def calGiniIndex(data,avg_set): #计算avg_set对应的每一维属性的gini指数,一并返回。    giniSet=[]        for index in range(avg_set.shape[0]):        d = data.iloc[:,index]#iloc,通过索引得到数据        subavg_set = avg_set[index]        sub_giniSet = []        for avg in subavg_set:                        ndata = data[d<=avg]  #得到小于平均数的数据            pdata = data[d>avg]            subdata =  [ndata,pdata]                                    gini = 0            for dd in subdata:                sum_cr = 0                count_result = dd.iloc[:,[index,-1]].groupby('class').count().iloc[:,0]                for cr in count_result:                    sum_cr += (cr/dd.shape[0])**2                for cr in count_result:                    gini += (cr/data.shape[0])*(1-sum_cr)            sub_giniSet.append(gini)        giniSet.append(sub_giniSet)        giniSet = pd.DataFrame(np.array(giniSet).T,columns=list('abcd'))        return giniSet        def process(data): #得到包含n-1个元素的连续值候选划分集合    avg_set=[]    for i in range(data.shape[1]-1):        subavg_set = []        d = data.iloc[:,i]        sorted_data = np.sort(d)#从小到大排序,直接返回的是array        for i in range(0,sorted_data.size-1):            subavg_set.append((sorted_data[i]+sorted_data[i+1])/2)        avg_set.append(subavg_set)    return np.array(avg_set)if __name__ == "__main__":    filePath = r"E:\data\iris.txt"    data= readText(filePath)    tree(data)    
运行结果:

现在判断属性cif属性c<1.9所属类别是:Iris-setosaif属性c>1.9现在判断属性dif属性d<1.7现在判断属性cif属性c<4.9现在判断属性dif属性d<1.6所属类别是:Iris-versicolorif属性d>1.6所属类别是:Iris-virginicaif属性c>4.9现在判断属性dif属性d<1.5所属类别是:Iris-virginicaif属性d>1.5现在判断属性aif属性a<6.95所属类别是:Iris-versicolorif属性a>6.95所属类别是:Iris-virginicaif属性d>1.7现在判断属性cif属性c<4.8现在判断属性aif属性a<5.95所属类别是:Iris-versicolorif属性a>5.95所属类别是:Iris-virginicaif属性c>4.8所属类别是:Iris-virginica

怎么样,还不错吧,我觉得ok~

剪枝优化下一篇做吧~





原创粉丝点击