决策树

来源:互联网 发布:全国各大高校校花知乎 编辑:程序博客网 时间:2024/06/10 15:09

决策树的实现代码

import numbers

import numpy as np


class DecisionNode(object):
    """A single node of a binary decision tree.

    An internal node stores the feature index and threshold used to route a
    sample to ``true_branch`` or ``false_branch``. A leaf node stores the
    prediction in ``value`` instead; ``value is not None`` marks a leaf.
    """

    def __init__(self, feature_i=None, threshold=None,
                 value=None, true_branch=None, false_branch=None):
        self.feature_i = feature_i        # index of the feature tested at this node
        self.threshold = threshold        # value the feature is compared against
        self.value = value                # prediction stored at a leaf (None otherwise)
        self.true_branch = true_branch    # subtree for samples that pass the test
        self.false_branch = false_branch  # subtree for samples that fail the test


class DecisionTree(object):
    """Base class for greedily, recursively built binary decision trees.

    Subclasses (e.g. an ID3 classifier) must set two callables before
    :meth:`fit` runs:

    * ``self._impurity_calculation(y, y1, y2)`` -- score of splitting the
      labels ``y`` into ``y1`` and ``y2``; larger is better.
    * ``self._leaf_value_calculation(y)`` -- value stored in a leaf whose
      samples have labels ``y``.

    Args:
        min_sample_split: Minimum number of samples a node needs to be split.
        min_impurity: A split is only kept if its score exceeds this value.
        max_depth: Nodes at a depth (root = 0) greater than this are not split.
    """

    def __init__(self, min_sample_split=2, min_impurity=1e-7,
                 max_depth=float("inf")):
        self.root = None
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.max_depth = max_depth
        # Function to calculate the score (e.g. information gain) of a split.
        self._impurity_calculation = None
        # Function to determine the value stored in a leaf node.
        self._leaf_value_calculation = None

    def fit(self, X, y):
        """Build the tree from 2-D samples ``X`` and labels ``y`` (1-D or column vector).

        Raises:
            NotImplementedError: if the split-score or leaf-value function
                has not been provided (e.g. when using the base class directly).
        """
        if self._impurity_calculation is None or self._leaf_value_calculation is None:
            raise NotImplementedError(
                "DecisionTree is a base class: set _impurity_calculation and "
                "_leaf_value_calculation (e.g. in a subclass) before fit()")
        self.root = self._build_tree(np.asarray(X), np.asarray(y))

    @staticmethod
    def _split_goes_true(value, threshold):
        """Return whether ``value`` belongs to the true branch for ``threshold``.

        Numeric thresholds (including numpy scalars) use ``value >= threshold``;
        any other threshold (e.g. a string category) uses equality. Works
        element-wise when ``value`` is a numpy array. Shared by training and
        prediction so both route samples identically.
        """
        if isinstance(threshold, numbers.Number):
            return value >= threshold
        return value == threshold

    def _build_tree(self, X, y, current_depth=0):
        """Recursively grow the subtree for samples ``X`` with labels ``y``.

        Returns an internal :class:`DecisionNode` when some split scores more
        than ``min_impurity``; otherwise a leaf holding
        ``self._leaf_value_calculation(y)``.
        """
        # Keep labels as a column vector so they can be stacked next to X.
        if np.ndim(y) == 1:
            y = np.expand_dims(y, axis=1)

        largest_impurity = 0
        best_criteria = None   # feature index and threshold of the best split
        best_sets = None       # data subsets produced by the best split

        X_y = np.concatenate((X, y), axis=1)
        n_samples, n_features = np.shape(X)

        if n_samples >= self.min_sample_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                column = X[:, feature_i]
                for threshold in np.unique(column):
                    mask = np.asarray(self._split_goes_true(column, threshold), dtype=bool)
                    Xy1, Xy2 = X_y[mask], X_y[~mask]
                    # A "split" that leaves one side empty separates nothing.
                    if len(Xy1) == 0 or len(Xy2) == 0:
                        continue
                    y1 = Xy1[:, n_features:]
                    y2 = Xy2[:, n_features:]
                    impurity = self._impurity_calculation(y, y1, y2)
                    if impurity > largest_impurity:
                        largest_impurity = impurity
                        best_criteria = {'feature_i': feature_i,
                                         'threshold': threshold}
                        best_sets = {'leftX': Xy1[:, :n_features],
                                     'lefty': Xy1[:, n_features:],
                                     'rightX': Xy2[:, :n_features],
                                     'righty': Xy2[:, n_features:]}

        if largest_impurity > self.min_impurity:
            true_branch = self._build_tree(best_sets['leftX'], best_sets['lefty'],
                                           current_depth + 1)
            false_branch = self._build_tree(best_sets['rightX'], best_sets['righty'],
                                            current_depth + 1)
            return DecisionNode(feature_i=best_criteria['feature_i'],
                                threshold=best_criteria['threshold'],
                                value=None,
                                true_branch=true_branch,
                                false_branch=false_branch)

        # No worthwhile split: this node becomes a leaf.
        leaf_value = self._leaf_value_calculation(y)
        return DecisionNode(value=leaf_value)

    def predict_value(self, x, tree=None):
        """Predict the value for a single sample ``x`` by walking down the tree."""
        if tree is None:
            tree = self.root
        # A leaf has been reached: return its stored prediction.
        if tree.value is not None:
            return tree.value
        feature_value = x[tree.feature_i]
        if self._split_goes_true(feature_value, tree.threshold):
            branch = tree.true_branch
        else:
            branch = tree.false_branch
        return self.predict_value(x, branch)

    def predict(self, X):
        """Return a list with one prediction per row of ``X``."""
        return [self.predict_value(x) for x in X]

可以根据所采用的算法（例如 ID3 使用信息增益，CART 使用基尼不纯度或方差减少）继承上述 DecisionTree 类，并在子类中自定义 self._impurity_calculation（划分评分函数）和 self._leaf_value_calculation（叶节点取值函数），示例如下：

import numpy as np


class ClassificationTree(DecisionTree):
    """Decision-tree classifier that splits on information gain (ID3-style).

    Supplies entropy-based information gain as the split score and a
    majority vote as the leaf value of the ``DecisionTree`` base class.
    """

    @staticmethod
    def _calculate_entropy(y):
        """Return the Shannon entropy (in bits) of the labels in ``y``; 0.0 if empty."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        p = counts / counts.sum()
        return float(-np.sum(p * np.log2(p)))

    def _calculate_information_gain(self, y, y1, y2):
        """Return the information gain of splitting labels ``y`` into ``y1`` and ``y2``."""
        # Fraction of the parent's samples that went to the first subset.
        p = len(y1) / len(y)
        entropy = self._calculate_entropy(y)
        info_gain = (entropy
                     - p * self._calculate_entropy(y1)
                     - (1 - p) * self._calculate_entropy(y2))
        return info_gain

    def _majority_vote(self, y):
        """Return the most frequent label in ``y``.

        Ties go to the smallest label (``np.unique`` sorts, ``argmax`` takes
        the first maximum). Returns None when ``y`` is empty.
        """
        labels, counts = np.unique(y, return_counts=True)
        if len(labels) == 0:
            return None
        return labels[np.argmax(counts)]

    def fit(self, X, y):
        """Fit the tree using information gain for splits and majority-vote leaves."""
        self._impurity_calculation = self._calculate_information_gain
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)
0 0
原创粉丝点击