CART决策树的Python实现
来源:互联网 发布:windows7内存优化 编辑:程序博客网 时间:2024/05/15 16:02
完整的代码请见:https://github.com/WiseDoge/ML-by-Python
CART决策树分类器
from collections import defaultdictimport numpy as npclass TreeNode(object): """决策树节点""" def __init__(self, **kwargs): ''' attr_index: 属性编号 attr: 属性值 label: 类别(y) left_chuld: 左子结点 right_child: 右子节点 ''' self.attr_index = kwargs.get('attr_index') self.attr = kwargs.get('attr') self.label = kwargs.get('label') self.left_child = kwargs.get('left_child') self.right_child = kwargs.get('right_child')class DecisionTreeClassifier(object): """ 决策树分类器。 本算法采用的是分类与回归树(classification and regression tree, CART) """ def __init__(self): # 决策树根节点 self.root = None def gini(self, cluster): ''' :param cluster: 训练集的一个子集 :return: 数据集的基尼系数 求给定数据集的基尼系数 ''' p = defaultdict(int) for line in cluster: p[line[-1]] += 1 temp = 1.0 for k, v in p.items(): temp -= (v / len(cluster)) ** 2 return temp def gini_index(self, cluster, attr_index): ''' :param cluster: 训练集的一个子集 :param attr_index: 特征编号(第N个特征) :return: 第N个特征的特征值, 该值的基尼指数 返回给定列标号下的最优切分属性和该属性的基尼指数 ''' p = defaultdict(list) for line in cluster: p[line[attr_index]].append(line) attr_gini = {} for k, v in p.items(): els = [] for k1, v1 in p.items(): if k1 == k: continue els.extend(v) count = (self.gini(v) * len(v) + self.gini(els) * len(els)) / len(cluster) attr_gini[k] = count attr = min(attr_gini, key=attr_gini.get) return attr, attr_gini[attr] def devide_set(self, cluster, index, attr): ''' :param cluster: 给定集合(为训练集的一个子集) :param index: 特征编号 :param attr: 特征值 :return: 左半部分,右半部分 将给定集合切分为两部分返回,第index个特征的特征值等于attr的为一组 不等于attr的为一组。 ''' left = [] right = [] for line in cluster: if line[index] == attr: left.append(line) else: right.append(line) return np.array(left), np.array(right) def get_best_index(self, cluster, attr_indexs): ''' :param cluster: 给定数据集 :param attr_indexs: 给定的可供切分的特征编号的集合 :return: 最佳切分点,最佳切分变量 求给定切分点集合中的最佳切分点和其对应的最佳切分变量 ''' p = {} for attr_index in attr_indexs: p[attr_index] = (self.gini_index(cluster, attr_index)) attr_index = min(p, key=lambda x: p.get(x)[1]) attr = p[attr_index][0] return attr_index, attr def build_tree(self, cluster, attr_indexs): ''' :param cluster: 给定数据集 :param attr_indexs: 给定的可供切分的特征编号的集合 :return: 一个决策树结点 递归构建决策树 ''' flag = cluster[0, -1] for i in cluster[:, -1]: if i != flag: break else: return TreeNode(label=flag) if not attr_indexs: p = defaultdict(int) for line in cluster: p[line[-1]] += 1 return TreeNode(label=max(p, key=p.get)) for i in attr_indexs: flag = cluster[i][0] f = False for j in cluster[:, i]: if j != flag: f = True break if f: break else: p = defaultdict(int) for line in cluster: p[line[-1]] += 1 return TreeNode(label=max(p, key=p.get)) attr_index, attr = self.get_best_index(cluster, attr_indexs) left, right = self.devide_set(cluster, attr_index, attr) new_attr_indexs = attr_indexs - set([attr_index]) left_branch = self.build_tree(left, new_attr_indexs) right_branch = self.build_tree(right, new_attr_indexs) return TreeNode(left_child=left_branch, right_child=right_branch, attr_index=attr_index, attr=attr) def fit(self, train_x, train_y): ''' :param train_x: 训练集合X :param train_y: 训练集合Y(target) :return: None 拟合决策树 ''' attr_indexs = set(range(train_x.shape[1])) self.train_x = np.c_[train_x, train_y] self.root = self.build_tree(self.train_x, attr_indexs) def predict_one(self, x): ''' :param x: 待预测的样本X :return: X所属的类别 预测单个值 ''' node_p = self.root while node_p.label == None: if x[node_p.attr_index] == node_p.attr: node_p = node_p.left_child else: node_p = node_p.right_child return node_p.label def predict(self, test_x): ''' :param test_x: 测试集 :return: 测试集样本的类别集合 预测多个值 ''' return np.array([self.predict_one(x) for x in test_x])
0 0
- CART决策树的Python实现
- 机器学习:决策树ID3\C4.5\CART\随机森林总结及python上的实现 (2)
- 机器学习算法的Python实现 (3):CART决策树与剪枝处理
- 机器学习算法的Python实现 (3):CART决策树与剪枝处理
- 决策树之CART算法原理及python实现
- 《统计学习方法》 决策树 CART生成算法 分类树 Python实现
- 《统计学习方法》 决策树 CART生成算法 回归树 Python实现
- 模式识别:分类回归决策树CART的研究与实现
- 模式识别十一--分类回归决策树CART的研究与实现
- CART决策树的sklearn实现及其GraphViz可视化
- 决策树的Python实现
- 决策树的python实现
- 决策树的python实现
- 决策树的python实现
- python cart算法的简单实现
- 决策树CART
- 决策树CART
- 决策树CART
- JAVA系书单
- boost smart_ptr -> scoped_ptr
- 卡片式布局 MD风格设计 卡片式背景
- nodejs-fs模块
- HttpUrlConnection使用时遇到的问题
- CART决策树的Python实现
- ElasticSearch安装
- 1119. Pre- and Post-order Traversals (30)
- 欢迎使用CSDN-markdown编辑器
- DB2中storage group的概念,以及创建各种类型的表空间
- Android 客户端扫描网页端二维码实现登录
- listview 分页加载
- 各种jar包下载http://www.java2s.com/Code/Jar/h/Catalogh.htm
- golang接口赋值操练