scikit-learn 中 GridSearchCV 的使用

来源:互联网 发布:超图软件工资怎么样 编辑:程序博客网 时间:2024/06/02 01:58

GridSearchCV使用介绍

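当算法效果不够好时,调参必不可少。比如 SVM 的惩罚因子 C、核函数 kernel、gamma 参数等,对于不同的数据使用不同的参数,结果可能相差 1-5 个点。为此 sklearn 提供了专门用于调参的网格搜索工具 GridSearchCV(旧版本位于 sklearn.grid_search 模块,0.18 起移至 sklearn.model_selection)。它对给定参数网格中的每一种参数组合做交叉验证,并选出得分最高的一组。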

函数介绍

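class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score='warn')

(以上是 scikit-learn 0.19 中的签名。在后续版本中,fit_params 与 iid 参数已被移除;n_jobs 默认值改为 None,cv 默认值由 3 折改为 5 折,error_score 默认值改为 np.nan,return_train_score 默认值改为 False。)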

参数介绍:

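  • estimator —— 待调参的模型对象(需实现 scikit-learn 的 estimator 接口,如 svm.SVC())
  • param_grid —— dict or list of dictionaries,键为参数名,值为该参数的候选取值列表;GridSearchCV 会穷举其中的全部组合。传入字典列表时,会分别搜索每个字典所定义的网格
  • scoring : 评分函数,可以是字符串(如 'accuracy'、'neg_mean_squared_error')或可调用对象;为 None 时使用 estimator 自身的 score 方法
  • fit_params : dict, optional,传给 estimator 的 fit 方法的额外参数(已弃用,新版本中改为在调用 fit() 时直接传入)
  • n_jobs : 并行任务个数,int, default=1;-1 表示使用全部 CPU 核
  • pre_dispatch : int, or string, optional, default '2*n_jobs',并行执行时预先分派的任务数,用于控制内存占用
  • iid : boolean, default=True,为 True 时假定各折数据同分布,按各折测试集样本数加权平均得分
  • cv : int, 交叉验证折数,默认3;也可以传入交叉验证生成器,或由 (train, test) 索引组成的可迭代对象
  • refit : boolean, or string, default=True,是否在找到最佳参数后用全部数据重新训练,得到 best_estimator_;多指标评估时需传入用于选择最佳参数的指标名
  • verbose : integer,日志详细程度,值越大输出的信息越多
  • error_score : 'raise' (default) or numeric,某组参数拟合出错时的处理方式:'raise' 直接抛出异常;若为数值,则将该次得分记为此值,并发出 FitFailedWarning 警告
  • return_train_score : 是否在 cv_results_ 中记录训练集得分(0.19 中默认 'warn',行为等同 True,但访问训练得分时会发出警告)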

下面是官网的一个使用案例

# Official scikit-learn example: grid-search the kernel and the penalty C of
# an SVC on the iris dataset.
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()

# Every combination of the listed values is evaluated:
# 2 kernels x 2 values of C = 4 candidate parameter settings.
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

# cv_results_ maps each key to one value per candidate, e.g. 'mean_fit_time',
# 'mean_score_time', 'mean_test_score', 'param_C', 'param_kernel', 'params',
# 'rank_test_score', 'split0_test_score', ..., 'std_test_score'.
# The exact key list depends on the scikit-learn version: there is one
# 'splitK_test_score' per CV fold (the default fold count changed from 3 to 5),
# and '*_train_score' keys only appear when return_train_score is enabled.
print(sorted(clf.cv_results_.keys()))

# The winning parameter combination found by the search.
print(clf.best_params_)

实例介绍

下面通过官网中比较 SVR 与核岭回归(KernelRidge)的示例,介绍 GridSearchCV 在 SVR 中的应用。
(示例来源:scikit-learn 官网)

# Authors: Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>
# License: BSD 3 clause
#
# Compare SVR and kernel ridge regression (KRR) on a noisy sine curve:
#   1. choose each model's hyper-parameters with GridSearchCV and time
#      fitting / prediction,
#   2. plot both fitted curves plus the SVR support vectors,
#   3. plot training / prediction time against training-set size,
#   4. plot learning curves (mean squared error vs. training-set size).

import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVR
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import learning_curve
from sklearn.kernel_ridge import KernelRidge

rng = np.random.RandomState(0)

# #############################################################################
# Generate sample data
X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()

# Add noise to targets: every 5th sample is shifted by a value in (-1.5, 1.5].
y[::5] += 3 * (0.5 - rng.rand(X.shape[0] // 5))

X_plot = np.linspace(0, 5, 100000)[:, None]

# #############################################################################
# Fit regression model
train_size = 100
# The gamma passed to the constructors is only a starting value: "gamma" is
# also in param_grid, so the grid search overrides it.
svr = GridSearchCV(SVR(kernel='rbf', gamma=0.1), cv=5,
                   param_grid={"C": [1e0, 1e1, 1e2, 1e3],
                               "gamma": np.logspace(-2, 2, 5)})

kr = GridSearchCV(KernelRidge(kernel='rbf', gamma=0.1), cv=5,
                  param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3],
                              "gamma": np.logspace(-2, 2, 5)})

# perf_counter is a monotonic, high-resolution clock intended for timing code.
t0 = time.perf_counter()
svr.fit(X[:train_size], y[:train_size])
svr_fit = time.perf_counter() - t0
print("SVR complexity and bandwidth selected and model fitted in %.3f s"
      % svr_fit)

t0 = time.perf_counter()
kr.fit(X[:train_size], y[:train_size])
kr_fit = time.perf_counter() - t0
print("KRR complexity and bandwidth selected and model fitted in %.3f s"
      % kr_fit)

# Fraction of the training samples that the best SVR keeps as support vectors.
sv_ratio = svr.best_estimator_.support_.shape[0] / train_size
print("Support vector ratio: %.3f" % sv_ratio)

t0 = time.perf_counter()
y_svr = svr.predict(X_plot)
svr_predict = time.perf_counter() - t0
print("SVR prediction for %d inputs in %.3f s"
      % (X_plot.shape[0], svr_predict))

t0 = time.perf_counter()
y_kr = kr.predict(X_plot)
kr_predict = time.perf_counter() - t0
print("KRR prediction for %d inputs in %.3f s"
      % (X_plot.shape[0], kr_predict))

# #############################################################################
# Look at the results
sv_ind = svr.best_estimator_.support_
plt.scatter(X[sv_ind], y[sv_ind], c='r', s=50, label='SVR support vectors',
            zorder=2, edgecolors=(0, 0, 0))
plt.scatter(X[:train_size], y[:train_size], c='k', label='data', zorder=1,
            edgecolors=(0, 0, 0))
plt.plot(X_plot, y_svr, c='r',
         label='SVR (fit: %.3fs, predict: %.3fs)' % (svr_fit, svr_predict))
plt.plot(X_plot, y_kr, c='g',
         label='KRR (fit: %.3fs, predict: %.3fs)' % (kr_fit, kr_predict))
plt.xlabel('data')
plt.ylabel('target')
plt.title('SVR versus Kernel Ridge')
plt.legend()

# Visualize training and prediction time
plt.figure()

# Generate sample data
X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(X.shape[0] // 5))
# Seven training-set sizes, log-spaced from 10 to 10000. Use the builtin int:
# the np.int alias was removed in NumPy 1.24.
sizes = np.logspace(1, 4, 7, dtype=int)
for name, estimator in {"KRR": KernelRidge(kernel='rbf', alpha=0.1,
                                           gamma=10),
                        "SVR": SVR(kernel='rbf', C=1e1, gamma=10)}.items():
    train_time = []
    test_time = []
    for train_test_size in sizes:
        t0 = time.perf_counter()
        estimator.fit(X[:train_test_size], y[:train_test_size])
        train_time.append(time.perf_counter() - t0)

        t0 = time.perf_counter()
        estimator.predict(X_plot[:1000])
        test_time.append(time.perf_counter() - t0)

    color = "r" if name == "SVR" else "g"
    plt.plot(sizes, train_time, 'o-', color=color,
             label="%s (train)" % name)
    plt.plot(sizes, test_time, 'o--', color=color,
             label="%s (test)" % name)

plt.xscale("log")
plt.yscale("log")
plt.xlabel("Train size")
plt.ylabel("Time (seconds)")
plt.title('Execution Time')
plt.legend(loc="best")

# Visualize learning curves
plt.figure()

svr = SVR(kernel='rbf', C=1e1, gamma=0.1)
kr = KernelRidge(kernel='rbf', alpha=0.1, gamma=0.1)
# learning_curve returns (train_sizes, train_scores, test_scores). The sizes
# are identical for both models, so the second copy is discarded.
train_sizes, train_scores_svr, test_scores_svr = \
    learning_curve(svr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
                   scoring="neg_mean_squared_error", cv=10)
_, train_scores_kr, test_scores_kr = \
    learning_curve(kr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
                   scoring="neg_mean_squared_error", cv=10)

# Scores are negated MSE (higher is better), so flip the sign to plot MSE;
# mean(1) averages over the CV folds.
plt.plot(train_sizes, -test_scores_svr.mean(1), 'o-', color="r",
         label="SVR")
plt.plot(train_sizes, -test_scores_kr.mean(1), 'o-', color="g",
         label="KRR")
plt.xlabel("Train size")
plt.ylabel("Mean Squared Error")
plt.title('Learning curves')
plt.legend(loc="best")

plt.show()
原创粉丝点击