Using GridSearchCV in scikit-learn
Source: Internet  Editor: 程序博客网  Time: 2024/06/02 01:58
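Introduction to GridSearchCV
When an algorithm does not perform well enough out of the box, parameter tuning is indispensable. For an SVM, for example, the penalty factor C, the kernel function (kernel), and the gamma parameter all matter: different data call for different settings, and the choice can change the result by roughly 1 to 5 percentage points. sklearn provides a dedicated tool for this kind of tuning, the GridSearchCV class in sklearn.model_selection (it lived in the sklearn.grid_search module in older releases).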
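To make the effect of these parameters concrete, here is a minimal sketch that scores a few hand-picked SVC settings with cross-validation on the iris dataset. The candidate values are hypothetical and chosen only for illustration; this manual loop is what grid search automates.

```python
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score

iris = datasets.load_iris()

# Hypothetical candidate settings, chosen only for illustration.
candidates = [
    {"kernel": "linear", "C": 1},
    {"kernel": "rbf", "C": 1, "gamma": 0.01},
    {"kernel": "rbf", "C": 10, "gamma": 0.1},
]

scores = {}
for params in candidates:
    model = svm.SVC(**params)
    # Mean accuracy over 5 cross-validation folds for this setting.
    cv_scores = cross_val_score(model, iris.data, iris.target, cv=5)
    scores[str(params)] = cv_scores.mean()

for name, score in scores.items():
    print("%s -> mean CV accuracy %.3f" % (name, score))
```

Comparing the printed scores shows how much the choice of C, kernel, and gamma matters on a given dataset.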
Function overview
class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score='warn')
Parameters:
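- estimator : the model (estimator object) to tune
- param_grid : dict or list of dictionaries, mapping parameter names to the candidate values to try
- scoring : the scoring function used to evaluate each candidate
- fit_params : dict, optional, parameters passed to the fit method (removed from the constructor in later scikit-learn releases)
- n_jobs : int, default=1, number of jobs to run in parallel
- pre_dispatch : int or string, optional, default='2*n_jobs', number of jobs dispatched during parallel execution
- iid : boolean, default=True (removed in later scikit-learn releases)
- cv : int, number of cross-validation folds, default 3 in the release described here (later releases default to 5)
- refit : boolean or string, default=True, whether to refit the best parameter setting on the whole dataset
- verbose : integer, controls how much is logged
- error_score : 'raise' (default) or numeric, the score assigned when fitting fails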
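The sketch below shows how several of these parameters fit together in one call. It is an illustration under assumptions: the grid values are made up, and `fit_params` and `iid` are left out so the call also works on releases where they no longer exist.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

iris = datasets.load_iris()

# A list of dicts searches each sub-grid separately, so gamma is only
# varied for the rbf kernel. The values are illustrative assumptions.
param_grid = [
    {"kernel": ["linear"], "C": [1, 10]},
    {"kernel": ["rbf"], "C": [1, 10], "gamma": [0.01, 0.1]},
]

search = GridSearchCV(
    estimator=SVC(),
    param_grid=param_grid,
    scoring="accuracy",  # metric used to rank the candidates
    cv=5,                # 5-fold cross-validation
    n_jobs=1,            # number of parallel jobs
    refit=True,          # refit the best setting on all the data
    verbose=0,
)
search.fit(iris.data, iris.target)

print("best parameters:", search.best_params_)
print("best mean CV accuracy: %.3f" % search.best_score_)
```

Because `refit=True`, the fitted `search` object can be used directly for prediction through `search.predict(...)`, which delegates to the best estimator.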
Below is a usage example from the official documentation:
>>> from sklearn import svm, datasets
>>> from sklearn.model_selection import GridSearchCV
>>> iris = datasets.load_iris()
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
>>> svc = svm.SVC()
>>> clf = GridSearchCV(svc, parameters)
>>> clf.fit(iris.data, iris.target)
...
GridSearchCV(cv=None, error_score=...,
       estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
                     decision_function_shape='ovr', degree=..., gamma=...,
                     kernel='rbf', max_iter=-1, probability=False,
                     random_state=None, shrinking=True, tol=...,
                     verbose=False),
       fit_params=None, iid=..., n_jobs=1,
       param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
       scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
...
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
 'mean_train_score', 'param_C', 'param_kernel', 'params',...
 'rank_test_score', 'split0_test_score',...
 'split0_train_score', 'split1_test_score', 'split1_train_score',...
 'split2_test_score', 'split2_train_score',...
 'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
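The keys listed above are easier to read once they are lined up per parameter combination. The following sketch repeats the same search and prints each candidate with its rank and mean test score; the variable names `results` and `rows` are introduced here for illustration only.

```python
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {"kernel": ("linear", "rbf"), "C": [1, 10]}
clf = GridSearchCV(svm.SVC(), parameters)
clf.fit(iris.data, iris.target)

results = clf.cv_results_
# Pair every parameter combination with its rank and mean test score,
# then order the rows from best rank to worst.
rows = sorted(
    zip(results["rank_test_score"],
        results["mean_test_score"],
        results["params"]),
    key=lambda row: row[0],
)
for rank, mean_score, params in rows:
    print("rank %d  mean test score %.3f  %s" % (rank, mean_score, params))
```

The row with rank 1 corresponds to the setting reported by `clf.best_params_`.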
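Worked example
The following simple example from the official documentation shows how GridSearchCV is applied to SVR (and, for comparison, to kernel ridge regression).
Official documentation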
# Authors: Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>
# License: BSD 3 clause

from __future__ import division
import time

import numpy as np

from sklearn.svm import SVR
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import learning_curve
from sklearn.kernel_ridge import KernelRidge
import matplotlib.pyplot as plt

rng = np.random.RandomState(0)

# #############################################################################
# Generate sample data
X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()

# Add noise to targets
y[::5] += 3 * (0.5 - rng.rand(X.shape[0] // 5))

X_plot = np.linspace(0, 5, 100000)[:, None]

# #############################################################################
# Fit regression model
train_size = 100
svr = GridSearchCV(SVR(kernel='rbf', gamma=0.1), cv=5,
                   param_grid={"C": [1e0, 1e1, 1e2, 1e3],
                               "gamma": np.logspace(-2, 2, 5)})

kr = GridSearchCV(KernelRidge(kernel='rbf', gamma=0.1), cv=5,
                  param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3],
                              "gamma": np.logspace(-2, 2, 5)})

t0 = time.time()
svr.fit(X[:train_size], y[:train_size])
svr_fit = time.time() - t0
print("SVR complexity and bandwidth selected and model fitted in %.3f s"
      % svr_fit)

t0 = time.time()
kr.fit(X[:train_size], y[:train_size])
kr_fit = time.time() - t0
print("KRR complexity and bandwidth selected and model fitted in %.3f s"
      % kr_fit)

sv_ratio = svr.best_estimator_.support_.shape[0] / train_size
print("Support vector ratio: %.3f" % sv_ratio)

t0 = time.time()
y_svr = svr.predict(X_plot)
svr_predict = time.time() - t0
print("SVR prediction for %d inputs in %.3f s"
      % (X_plot.shape[0], svr_predict))

t0 = time.time()
y_kr = kr.predict(X_plot)
kr_predict = time.time() - t0
print("KRR prediction for %d inputs in %.3f s"
      % (X_plot.shape[0], kr_predict))

# #############################################################################
# Look at the results
sv_ind = svr.best_estimator_.support_
plt.scatter(X[sv_ind], y[sv_ind], c='r', s=50, label='SVR support vectors',
            zorder=2, edgecolors=(0, 0, 0))
plt.scatter(X[:100], y[:100], c='k', label='data', zorder=1,
            edgecolors=(0, 0, 0))
plt.plot(X_plot, y_svr, c='r',
         label='SVR (fit: %.3fs, predict: %.3fs)' % (svr_fit, svr_predict))
plt.plot(X_plot, y_kr, c='g',
         label='KRR (fit: %.3fs, predict: %.3fs)' % (kr_fit, kr_predict))
plt.xlabel('data')
plt.ylabel('target')
plt.title('SVR versus Kernel Ridge')
plt.legend()

# Visualize training and prediction time
plt.figure()

# Generate sample data
X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(X.shape[0] // 5))
# The np.int alias has been removed from recent NumPy; the builtin int works.
sizes = np.logspace(1, 4, 7, dtype=int)
for name, estimator in {"KRR": KernelRidge(kernel='rbf', alpha=0.1,
                                           gamma=10),
                        "SVR": SVR(kernel='rbf', C=1e1, gamma=10)}.items():
    train_time = []
    test_time = []
    for train_test_size in sizes:
        t0 = time.time()
        estimator.fit(X[:train_test_size], y[:train_test_size])
        train_time.append(time.time() - t0)

        t0 = time.time()
        estimator.predict(X_plot[:1000])
        test_time.append(time.time() - t0)

    plt.plot(sizes, train_time, 'o-', color="r" if name == "SVR" else "g",
             label="%s (train)" % name)
    plt.plot(sizes, test_time, 'o--', color="r" if name == "SVR" else "g",
             label="%s (test)" % name)

plt.xscale("log")
plt.yscale("log")
plt.xlabel("Train size")
plt.ylabel("Time (seconds)")
plt.title('Execution Time')
plt.legend(loc="best")

# Visualize learning curves
plt.figure()

svr = SVR(kernel='rbf', C=1e1, gamma=0.1)
kr = KernelRidge(kernel='rbf', alpha=0.1, gamma=0.1)
train_sizes, train_scores_svr, test_scores_svr = \
    learning_curve(svr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
                   scoring="neg_mean_squared_error", cv=10)
train_sizes_abs, train_scores_kr, test_scores_kr = \
    learning_curve(kr, X[:100], y[:100], train_sizes=np.linspace(0.1, 1, 10),
                   scoring="neg_mean_squared_error", cv=10)

plt.plot(train_sizes, -test_scores_svr.mean(1), 'o-', color="r",
         label="SVR")
plt.plot(train_sizes, -test_scores_kr.mean(1), 'o-', color="g",
         label="KRR")
plt.xlabel("Train size")
plt.ylabel("Mean Squared Error")
plt.title('Learning curves')
plt.legend(loc="best")

plt.show()
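The script above times and plots the fitted models but never prints which hyperparameters the search selected. As a simplified sketch (it draws only 100 noisy samples directly instead of slicing 10000, so the numbers will not match the script exactly), the chosen values can be read from the fitted search object like this:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

rng = np.random.RandomState(0)

# Simplified data: 100 samples of a noisy sine curve.
X = 5 * rng.rand(100, 1)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(X.shape[0] // 5))

# Same grid as in the example: 4 values of C times 5 values of gamma.
svr = GridSearchCV(SVR(kernel="rbf", gamma=0.1), cv=5,
                   param_grid={"C": [1e0, 1e1, 1e2, 1e3],
                               "gamma": np.logspace(-2, 2, 5)})
svr.fit(X, y)

print("selected C and gamma:", svr.best_params_)
# With no scoring argument, a regressor is ranked by its R^2 score.
print("mean CV R^2 of the best setting: %.3f" % svr.best_score_)
print("support vectors in the refitted model:",
      svr.best_estimator_.support_.shape[0])
```

The `gamma=0.1` passed to the SVR constructor is only a starting value; the grid overrides it for every candidate.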