决策树转规则

来源:互联网 发布:米尔斯海默 知乎 编辑:程序博客网 时间:2024/06/01 10:35


有些决策、分类规则手写起来比较麻烦,而直接使用机器学习模型(例如逻辑回归 LR)又难以运营和解释。这时,可以先训练一棵节点数较少的决策树,再把它转换成 if/else 形式的规则,这是一个折中的解决方案。


import numpy as np


def load_samples(path, label_threshold=60):
    """Read whitespace-separated training samples from *path*.

    Every non-blank line must contain at least six fields:

    * field 1 is the raw score; it is binarised into the label as
      ``score > label_threshold`` (True above the threshold, False otherwise);
    * fields 2-5 are the four numeric features;
    * field 0 and anything after field 5 are ignored.

    Args:
        path: Path of the sample file.
        label_threshold: Score above which a sample is labelled True.

    Returns:
        A ``(features, labels)`` tuple: a list of 4-element float lists and a
        parallel list of bools.

    Raises:
        ValueError: If a non-blank line has fewer than six fields or a used
            field is not a valid number.
    """
    features = []
    labels = []
    with open(path) as fd:
        for line_no, line in enumerate(fd, start=1):
            # split() without arguments tolerates repeated spaces, tabs and
            # the trailing newline.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. at the end of the file) carry no sample.
                continue
            if len(fields) < 6:
                raise ValueError(
                    "{}:{}: expected at least 6 fields, got {}".format(
                        path, line_no, len(fields)
                    )
                )
            features.append([float(field) for field in fields[2:6]])
            labels.append(float(fields[1]) > label_threshold)
    return features, labels


def tree_to_code(tree, feature_names):
    """Print a fitted decision tree as a nested if/else Python function.

    The generated source is written to standard output; nothing is returned.

    Args:
        tree: A fitted estimator exposing ``tree_`` with the array attributes
            ``feature``, ``threshold``, ``children_left``, ``children_right``
            and ``value`` (e.g. a scikit-learn ``DecisionTreeRegressor``).
        feature_names: Names of the input features, in training column
            order. They become the parameters of the generated function.
    """
    tree_ = tree.tree_
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        left = tree_.children_left[node]
        right = tree_.children_right[node]
        if left == right:
            # A leaf has no children, so both child ids hold the same
            # "no child" sentinel; a split node always has two distinct ones.
            print("{}return {}".format(indent, tree_.value[node]))
            return
        name = feature_names[tree_.feature[node]]
        threshold = tree_.threshold[node]
        print("{}if {} <= {}:".format(indent, name, threshold))
        recurse(left, depth + 1)
        print("{}else:  # if {} > {}".format(indent, name, threshold))
        recurse(right, depth + 1)

    # The root is node 0; depth 1 indents it inside the generated "def".
    recurse(0, 1)


def main():
    """Fit a small regression tree on 'vm06.xy' and print it as rules."""
    # Imported here so that load_samples() and tree_to_code() stay importable
    # on machines without scikit-learn installed.
    from sklearn.tree import DecisionTreeRegressor

    trainx, trainy = load_samples('vm06.xy')
    features = np.asarray(trainx)

    # Few leaves keep the generated rule set short enough to read and maintain.
    regressor = DecisionTreeRegressor(max_leaf_nodes=8)
    regressor.fit(features, np.asarray(trainy))

    # Spot-check a handful of training rows against their labels.
    res = regressor.predict(features[39:51])
    print(res, trainy[39:51])

    tree_to_code(regressor, ["length", "width", "height", "fps"])


if __name__ == "__main__":
    main()

输出(节选,tree_to_code 打印出的规则函数):


def tree(length, width, height, fps):
  if length <= 205.5:
    if length <= 91.5:
      return [[ 0.00090733]]
    else:  # if length > 91.5
      if width <= 1703.0:
        return [[ 0.02891943]]
      else:  # if width > 1703.0
        return [[ 0.81340058]]
  else:  # if length > 205.5
    if width <= 859.0:
      if length <= 795.0:
        return [[ 0.05918367]]
      else:  # if length > 795.0
        return [[ 0.75434531]]
    else:  # if width > 859.0
      if height <= 702.0:
        if length <= 596.5:
          return [[ 0.12064343]]
        else:  # if length > 596.5
          return [[ 0.93028025]]
      else:  # if height > 702.0
        return [[ 0.892728]]