A CART decision tree for continuous-valued attributes
Dataset: Iris.
The data looks like this:
a b c d class
0 5.1 3.5 1.4 0.2 Iris-setosa
1 4.9 3.0 1.4 0.2 Iris-setosa
2 4.7 3.2 1.3 0.2 Iris-setosa
3 4.6 3.1 1.5 0.2 Iris-setosa
4 5.0 3.6 1.4 0.2 Iris-setosa
Four continuous attributes plus a class attribute.
Splitting criterion: the Gini index.
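For reference, the weighted Gini index of a binary split can be sketched as below (a minimal stand-alone illustration, not the code used in this post; the function names are my own):

```python
from collections import Counter

def gini_impurity(labels):
    # Gini impurity of one subset: 1 - sum_k p_k^2.
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def gini_index(left, right):
    # Gini index of a binary split: each side's impurity
    # weighted by its share of the samples.
    n = len(left) + len(right)
    return ((len(left) / n) * gini_impurity(left)
            + (len(right) / n) * gini_impurity(right))
```

A pure subset has impurity 0, so a split that separates the classes perfectly has Gini index 0; the best split is the candidate with the smallest index.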
Handling continuous values: for each of the four attribute columns, sort the values and take the midpoint of every pair of adjacent values, giving an (n-1)×4 set of candidate split points; for each column, compute the Gini index of every candidate split, and take the attribute/split-point pair with the smallest value as the best split.
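A minimal sketch of the candidate-split construction described above (my own stand-alone version, not the post's `process` function):

```python
import numpy as np

def candidate_splits(values):
    # Sort one attribute column; candidates are the midpoints of
    # adjacent sorted values, so n values yield n-1 candidates.
    s = np.sort(np.asarray(values, dtype=float))
    return (s[:-1] + s[1:]) / 2.0
```

Note that duplicated values produce degenerate candidates equal to the value itself; deduplicating first (e.g. with `np.unique`) would avoid them, though the post's code does not.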
One thing worth mentioning: a decision tree on continuous values overfits very easily, because the stopping condition is strict. After a binary split on some attribute, say a<=7.2, that attribute cannot be excluded from later splits the way a used attribute is in a discrete-valued tree, because a can still be split again, e.g. as a<=3.5. So the stopping condition really needs to be completed by a pruning strategy. But I was so happy once the splits came out that I couldn't wait to post them; pruning can wait until next time.
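A toy illustration of the point above, with made-up numbers: after splitting on a<=7.2, attribute a still yields fresh thresholds inside the left branch.

```python
import numpy as np

# Hypothetical values of attribute "a" at some node (not real Iris data).
a = np.array([3.0, 4.0, 6.5, 7.9, 8.4])

# First split: a <= 7.2 keeps three samples in the left branch.
left = np.sort(a[a <= 7.2])                    # [3.0, 4.0, 6.5]

# Inside that branch, "a" is still a candidate attribute:
# its remaining values give new midpoint thresholds.
new_thresholds = (left[:-1] + left[1:]) / 2.0  # [3.5, 5.25]
```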
The code:
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 20 11:16:37 2017
@author: wjw
"""
import numpy as np
import pandas as pd

def readText(filePath):
    # Parse the comma-separated Iris file into a DataFrame.
    lines = open(filePath, 'r').readlines()
    data = []
    for line in lines:
        dataList = line.split(',')
        data.append([float(dataList[0]), float(dataList[1]), float(dataList[2]),
                     float(dataList[3]), dataList[4].split("\n")[0]])
    data = pd.DataFrame(data, columns=["a", "b", "c", "d", "class"])
    return data

"""
     a    b    c    d        class
0  5.1  3.5  1.4  0.2  Iris-setosa
1  4.9  3.0  1.4  0.2  Iris-setosa
2  4.7  3.2  1.3  0.2  Iris-setosa
3  4.6  3.1  1.5  0.2  Iris-setosa
4  5.0  3.6  1.4  0.2  Iris-setosa
"""

def binSplitData(data, feature, value):
    # Split the data in two on the given feature at the given value.
    data0 = data[data[feature] <= value]
    data1 = data[data[feature] > value]
    binData = [data0, data1]  # a pair of DataFrames
    return binData

def chooseBestFeatureToSplit(data):
    avg_set = process(data)
    gini = calGiniIndex(data, avg_set)
    min_avg, minColumn = getMINGini(gini, avg_set)
    return min_avg, minColumn

def tree(data):
    # Per-class sample counts at the current node.
    countList = data.groupby('class').count().iloc[:, 0]
    if countList.iloc[0] == data.shape[0]:
        # All samples belong to the same class: emit a leaf.
        print("class is: %s" % (data.iloc[0, -1]))
        return
    min_avg, minColumn = chooseBestFeatureToSplit(data)
    print("now splitting on attribute %s" % (minColumn))
    binData = binSplitData(data, minColumn, min_avg)
    for i in range(2):
        if i == 0:
            print("if attribute %s<=%s" % (minColumn, min_avg))  # left branch condition
        else:
            print("if attribute %s>%s" % (minColumn, min_avg))   # right branch condition
        tree(binData[i])
    return

def getMINGini(gini, avg_set):
    # Find the attribute and split point with the smallest Gini index.
    minV = 10.
    for column_index in range(gini.shape[1]):
        gini_column = gini.iloc[:, column_index]
        newmin = min(gini_column)
        if newmin < minV:
            minV = newmin
            minColumn = column_index
    rows = gini.iloc[:, minColumn][gini.iloc[:, minColumn] == minV].index.tolist()
    min_avg = avg_set[minColumn][rows]
    return min_avg[0], gini.columns[minColumn]

def calGiniIndex(data, avg_set):
    # For every attribute, compute the Gini index of each candidate split.
    giniSet = []
    for index in range(avg_set.shape[0]):
        d = data.iloc[:, index]
        subavg_set = avg_set[index]
        sub_giniSet = []
        for avg in subavg_set:
            ndata = data[d <= avg]  # samples at or below the split point
            pdata = data[d > avg]   # samples above the split point
            gini = 0
            for dd in [ndata, pdata]:
                sum_cr = 0
                count_result = dd.iloc[:, [index, -1]].groupby('class').count().iloc[:, 0]
                for cr in count_result:
                    sum_cr += (cr / dd.shape[0]) ** 2
                # Weight each side's impurity by its share of the samples.
                gini += (dd.shape[0] / data.shape[0]) * (1 - sum_cr)
            sub_giniSet.append(gini)
        giniSet.append(sub_giniSet)
    giniSet = pd.DataFrame(np.array(giniSet).T, columns=list('abcd'))
    return giniSet

def process(data):
    # Candidate split points: the n-1 midpoints of adjacent sorted values
    # in each attribute column.
    avg_set = []
    for i in range(data.shape[1] - 1):
        subavg_set = []
        sorted_data = np.sort(data.iloc[:, i])  # ascending, returns an array
        for j in range(0, sorted_data.size - 1):
            subavg_set.append((sorted_data[j] + sorted_data[j + 1]) / 2)
        avg_set.append(subavg_set)
    return np.array(avg_set)

if __name__ == "__main__":
    filePath = r"E:\data\iris.txt"
    data = readText(filePath)
    tree(data)

The output:
now splitting on attribute c
if attribute c<=1.9
class is: Iris-setosa
if attribute c>1.9
now splitting on attribute d
if attribute d<=1.7
now splitting on attribute c
if attribute c<=4.9
now splitting on attribute d
if attribute d<=1.6
class is: Iris-versicolor
if attribute d>1.6
class is: Iris-virginica
if attribute c>4.9
now splitting on attribute d
if attribute d<=1.5
class is: Iris-virginica
if attribute d>1.5
now splitting on attribute a
if attribute a<=6.95
class is: Iris-versicolor
if attribute a>6.95
class is: Iris-virginica
if attribute d>1.7
now splitting on attribute c
if attribute c<=4.8
now splitting on attribute a
if attribute a<=5.95
class is: Iris-versicolor
if attribute a>5.95
class is: Iris-virginica
if attribute c>4.8
class is: Iris-virginica
Not bad, right? Looks OK to me~
Pruning and optimization in the next post~