【Python学习系列十】Python机器学习库scikit-learn实现Decision Trees案例

来源:互联网 发布:淘宝客退款佣金怎么算 编辑:程序博客网 时间:2024/06/05 15:36

学习网址:http://scikit-learn.org/stable/modules/tree.html

scikit-learn这个官网很好,里面有算法案例也有算法原理说明。

案例代码:

# -*- coding: utf-8 -*-__author__ = 'Jason.F'#http://scikit-learn.org/stable/modules/tree.htmlimport numpy as npimport matplotlib.pyplot as pltfrom sklearn.tree import DecisionTreeRegressor# Create a random datasetrng = np.random.RandomState(1)X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).Ty[::5, :] += (0.5 - rng.rand(20, 2))# Fit regression modelregr_1 = DecisionTreeRegressor(max_depth=2)regr_2 = DecisionTreeRegressor(max_depth=5)regr_3 = DecisionTreeRegressor(max_depth=8)regr_1.fit(X, y)regr_2.fit(X, y)regr_3.fit(X, y)# PredictX_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]y_1 = regr_1.predict(X_test)y_2 = regr_2.predict(X_test)y_3 = regr_3.predict(X_test)# Plot the resultsplt.figure()s = 50plt.scatter(y[:, 0], y[:, 1], c="navy", s=s, label="data")plt.scatter(y_1[:, 0], y_1[:, 1], c="cornflowerblue", s=s, label="max_depth=2")plt.scatter(y_2[:, 0], y_2[:, 1], c="c", s=s, label="max_depth=5")plt.scatter(y_3[:, 0], y_3[:, 1], c="orange", s=s, label="max_depth=8")plt.xlim([-6, 6])plt.ylim([-6, 6])plt.xlabel("target 1")plt.ylabel("target 2")plt.title("Multi-output Decision Tree Regression")plt.legend()plt.show()

结果:


对scikit-learn库大体上拿决策树和支持向量机来了解,后面就是要具体应用。

阅读全文
0 0
原创粉丝点击