回归决策树

来源:互联网 发布:河北省中标数据网 编辑:程序博客网 时间:2024/05/22 11:05

决策树是处理分类的常用算法,但它也可以用来处理回归问题,其关键在于选择最佳分割点,基本思路是:遍历所有数据,尝试每个数据作为分割点,并计算此时左右两侧的数据的离差平方和,并从中找到最小值,然后找到离差平方和最小时对应的数据,它就是最佳分割点。sklearn.tree.DecisionTreeRegressor函数即利用决策树处理回归问题,树的深度越高拟合效果越好,也更容易发生过拟合。

回归决策树实践代码及效果如下:

#!/usr/bin/python# -*- coding:utf-8 -*-import numpy as npimport matplotlib.pyplot as pltfrom sklearn.tree import DecisionTreeRegressorif __name__ == "__main__":    N = 100    x = np.random.rand(N) * 6 - 3     # [-3,3)    x.sort()    y = np.sin(x) + np.random.randn(N) * 0.05    x = x.reshape(-1, 1)  # 转置后,得到N个样本,每个样本都是1维的    reg = DecisionTreeRegressor(criterion='mse', max_depth=9)    dt = reg.fit(x, y)    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)    y_hat = dt.predict(x_test)    plt.plot(x, y, 'r*', ms=10, label='Actual')    plt.plot(x_test, y_hat, 'g-', linewidth=2, label='Predict')    plt.legend(loc='upper left')    plt.grid()    plt.show()    # 比较决策树的深度影响    depth = [2, 4, 6, 8, 10]    clr = 'rgbmy'    reg = [DecisionTreeRegressor(criterion='mse', max_depth=depth[0]),           DecisionTreeRegressor(criterion='mse', max_depth=depth[1]),           DecisionTreeRegressor(criterion='mse', max_depth=depth[2]),           DecisionTreeRegressor(criterion='mse', max_depth=depth[3]),           DecisionTreeRegressor(criterion='mse', max_depth=depth[4])]    plt.plot(x, y, 'k^', linewidth=2, label='Actual')    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)    for i, r in enumerate(reg):        dt = r.fit(x, y)        y_hat = dt.predict(x_test)        plt.plot(x_test, y_hat, '-', color=clr[i], linewidth=2, label='Depth=%d' % depth[i])    plt.legend(loc='upper left')    plt.grid()    plt.show()

这里写图片描述

原创粉丝点击