决策树
来源:互联网 发布:全国各大高校校花知乎 编辑:程序博客网 时间:2024/06/10 15:09
决策树的实现代码
class DecisionNode(object):
    """A single node in a decision tree.

    A node is either internal (``feature_i``/``threshold`` set, with two
    child branches) or a leaf (``value`` set, branches ``None``).
    """

    def __init__(self, feature_i=None, threshold=None, value=None,
                 true_branch=None, false_branch=None):
        self.feature_i = feature_i        # index of the feature tested at this node
        self.threshold = threshold        # threshold the feature is compared against
        self.value = value                # prediction stored here if this is a leaf
        self.true_branch = true_branch    # subtree for samples satisfying the split
        self.false_branch = false_branch  # subtree for the remaining samples


def divide_on_feature(X, feature_i, threshold):
    """Split the rows of ``X`` on feature ``feature_i``.

    Numeric features split by ``>= threshold``; other (categorical) values
    split by equality.  Returns the (matching, non-matching) sub-arrays.
    NOTE: this helper was called but never defined in the original article;
    it is reconstructed here so the class is runnable.
    """
    if isinstance(threshold, (int, float, np.integer, np.floating)):
        split_func = lambda sample: sample[feature_i] >= threshold
    else:
        split_func = lambda sample: sample[feature_i] == threshold
    X_1 = np.array([sample for sample in X if split_func(sample)])
    X_2 = np.array([sample for sample in X if not split_func(sample)])
    return X_1, X_2


class DecisionTree(object):
    """Base class for decision trees (classification or regression).

    Subclasses must set ``self._impurity_calculation`` (scores a candidate
    split) and ``self._leaf_value_calculation`` (computes the value stored
    at a leaf) before :meth:`fit` is called.
    """

    def __init__(self, min_sample_split=2, min_impurity=1e-7,
                 max_depth=float("inf")):
        self.root = None
        self.min_sample_split = min_sample_split  # min samples required to attempt a split
        self.min_impurity = min_impurity          # min impurity gain required to split
        self.max_depth = max_depth                # max depth the tree may grow to
        # Function to calculate impurity of a candidate split (set by subclass).
        # BUG FIX: originally misspelled `_impurity_caculation`, which never
        # matched the attribute the subclass assigns in its fit().
        self._impurity_calculation = None
        # Function to determine the value of a leaf node (set by subclass).
        self._leaf_value_calculation = None

    def fit(self, X, y):
        """Grow the tree from training data X (n_samples, n_features) and labels y."""
        self.root = self._build_tree(X, y)

    def _build_tree(self, X, y, current_depth=0):
        """Recursively build the tree, greedily choosing the split with the
        largest impurity gain; returns the subtree's root DecisionNode."""
        largest_impurity = 0
        best_criteria = None  # feature index and threshold of the best split
        best_sets = None      # data subsets produced by the best split

        # BUG FIX: ensure y is a column vector so it can be stacked next to X
        # (np.concatenate on axis=1 fails for 1-D y).
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)
        X_y = np.concatenate((X, y), axis=1)
        n_samples, n_features = np.shape(X)

        if n_samples >= self.min_sample_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                # BUG FIX: original referenced an undefined `feature_values`.
                feature_values = X_y[:, feature_i]
                for threshold in np.unique(feature_values):
                    Xy1, Xy2 = divide_on_feature(X_y, feature_i, threshold)
                    # Skip degenerate splits where one side is empty.
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]
                        impurity = self._impurity_calculation(y, y1, y2)
                        if impurity > largest_impurity:
                            largest_impurity = impurity
                            best_criteria = {'feature_i': feature_i,
                                             'threshold': threshold}
                            best_sets = {'leftX': Xy1[:, :n_features],
                                         'lefty': Xy1[:, n_features:],
                                         'rightX': Xy2[:, :n_features],
                                         'righty': Xy2[:, n_features:]}

        if largest_impurity > self.min_impurity:
            true_branch = self._build_tree(best_sets['leftX'],
                                           best_sets['lefty'],
                                           current_depth + 1)
            false_branch = self._build_tree(best_sets['rightX'],
                                            best_sets['righty'],
                                            current_depth + 1)
            return DecisionNode(feature_i=best_criteria['feature_i'],
                                threshold=best_criteria['threshold'],
                                true_branch=true_branch,
                                false_branch=false_branch)

        # BUG FIX: the original fell off the end and returned None here,
        # so the tree had no leaves and predict_value crashed.  Create a
        # leaf holding the subclass-defined leaf value instead.
        leaf_value = self._leaf_value_calculation(y)
        return DecisionNode(value=leaf_value)

    def predict_value(self, x, tree=None):
        """Classify one sample by walking the tree from ``tree`` (default: root)."""
        if tree is None:
            tree = self.root
        if tree.value is not None:  # reached a leaf
            return tree.value
        feature_value = x[tree.feature_i]
        branch = tree.false_branch
        # BUG FIX: include NumPy scalar types; plain isinstance(v, int/float)
        # is False for np.int64/np.float64 read out of an ndarray.
        if isinstance(feature_value, (int, float, np.integer, np.floating)):
            if feature_value >= tree.threshold:
                branch = tree.true_branch
        elif feature_value == tree.threshold:
            branch = tree.true_branch
        return self.predict_value(x, branch)

    def predict(self, X):
        """Predict every row of X; returns a list of predictions."""
        return [self.predict_value(x) for x in X]
可以根据ID3或CART自定义self._impurity_calculation函数体,并继承上述类
class ClassificationTree(DecisionTree): #### here is ID3 def _calculate_information_gain(self, y, y1, y2): # Calculate information gain p = len(y1) / len(y) entropy = calculate_entropy(y) info_gain = entropy - p * \ calculate_entropy(y1) - (1 - p) * \ calculate_entropy(y2) ### entropy calculation omitted return info_gain def _majority_vote(self, y): most_common = None max_count = 0 for label in np.unique(y): # Count number of occurences of samples with label count = len(y[y == label]) if count > max_count: most_common = label max_count = count return most_common def fit(self, X, y): self._impurity_calculation = self._calculate_information_gain self._leaf_value_calculation = self._majority_vote super(ClassificationTree, self).fit(X, y)
0 0
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- 决策树
- EOJ 3263丽娃河的狼人传说(贪心)
- spring使用jpa进行update操作
- java实现定时任务的三种方法
- v-bind和v-on
- 多线程(同步锁)
- 决策树
- 蓝桥杯训练:爆搜——四平方和
- jQuery实现全选取消反选
- 【十分钟读懂系列】之什么是SLF,PSL,MLF,SLO?
- WebSocket 是什么原理?为什么可以实现持久连接?
- centos 7.1 apache 源码编译安装
- 聊聊跑步
- java中map的四种取值方式
- C/C++: 实现加减乘除。