回归决策树
来源:互联网 发布:河北省中标数据网 编辑:程序博客网 时间:2024/05/22 11:05
决策树是处理分类的常用算法,但它也可以用来处理回归问题,其关键在于选择最佳分割点,基本思路是:遍历所有数据,尝试每个数据作为分割点,并计算此时左右两侧的数据的离差平方和,并从中找到最小值,然后找到离差平方和最小时对应的数据,它就是最佳分割点。sklearn.tree.DecisionTreeRegressor函数即利用决策树处理回归问题,树的深度越高拟合效果越好,也更容易发生过拟合。
回归决策树实践代码及效果如下:
#!/usr/bin/python# -*- coding:utf-8 -*-import numpy as npimport matplotlib.pyplot as pltfrom sklearn.tree import DecisionTreeRegressorif __name__ == "__main__": N = 100 x = np.random.rand(N) * 6 - 3 # [-3,3) x.sort() y = np.sin(x) + np.random.randn(N) * 0.05 x = x.reshape(-1, 1) # 转置后,得到N个样本,每个样本都是1维的 reg = DecisionTreeRegressor(criterion='mse', max_depth=9) dt = reg.fit(x, y) x_test = np.linspace(-3, 3, 50).reshape(-1, 1) y_hat = dt.predict(x_test) plt.plot(x, y, 'r*', ms=10, label='Actual') plt.plot(x_test, y_hat, 'g-', linewidth=2, label='Predict') plt.legend(loc='upper left') plt.grid() plt.show() # 比较决策树的深度影响 depth = [2, 4, 6, 8, 10] clr = 'rgbmy' reg = [DecisionTreeRegressor(criterion='mse', max_depth=depth[0]), DecisionTreeRegressor(criterion='mse', max_depth=depth[1]), DecisionTreeRegressor(criterion='mse', max_depth=depth[2]), DecisionTreeRegressor(criterion='mse', max_depth=depth[3]), DecisionTreeRegressor(criterion='mse', max_depth=depth[4])] plt.plot(x, y, 'k^', linewidth=2, label='Actual') x_test = np.linspace(-3, 3, 50).reshape(-1, 1) for i, r in enumerate(reg): dt = r.fit(x, y) y_hat = dt.predict(x_test) plt.plot(x_test, y_hat, '-', color=clr[i], linewidth=2, label='Depth=%d' % depth[i]) plt.legend(loc='upper left') plt.grid() plt.show()
阅读全文
0 0
- 回归决策树
- 回归- 决策树
- 回归决策树
- 比较决策树和回归
- 决策树&逻辑回归
- sklearn中的回归决策树
- 决策树回归R语言实现
- CART决策树分类和回归
- 决策树之分类回归树(C&RT)
- (笔记)列联表分析,Logistic回归,到决策树
- 决策树 逻辑回归 KNN 的原理
- 决策树分类与回归(一)
- 逻辑斯蒂回归和决策树
- 机器学习决策树:sklearn分类和回归
- 决策树回归:不掉包源码实现
- 逻辑回归与决策树在分类上的一些区别
- 逻辑回归与决策树在分类上的一些区别
- 逻辑回归与决策树在分类上的一些区别
- spring+springMVC+hibernate事务配置
- 严蔚敏版数据结构学习笔记(1):线性表的顺序表示和实现
- HDU
- Spark性能优化指南——高级篇
- tcp协议三次握手、四次挥手及其他
- 回归决策树
- 掌控之外,收获之中
- MongoDB主从模式,复制模式比较
- appium自动化参考博客
- 实施定量风险分析的工具 EMV分析与决策树学习
- uva10655(矩阵快速幂)
- CentOS Linux 6.8 tomcat的启动关闭
- 对于json的理解
- join()方法