sklearn学习笔记-《超参数优化方法》

来源:互联网 发布:sql中len的用法 编辑:程序博客网 时间:2024/06/05 16:29

超参数:

学习器模型中一般有两种参数,一种参数是可以从学习中得到,还有一种无法靠数据里面得到,只能靠人的经验来设定,这类参数就叫做超参数。

优化超参数:

参数空间是由
1.一个回归器或者一个分类器
2.一个参数空间
3.一个搜索或者采样机制来获得候选你参数
4.一个交叉验证机制
5.一个评分函数

有两种优化超参数的方法
1.网格搜索(GridSearchCV)
GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。这个时候就是需要动脑筋了。数据量比较大的时候可以使用一个快速调优的方法——坐标下降它其实是一种贪心算法:拿当前对模型影响最大的参数调优,直到最优化再拿下一个影响最大的参数调优,如此下去,直到所有的参数调整完毕。这个方法的缺点就是可能会调到局部最优而不是全局最优,但是省时间省力,巨大的优势面前,还是试一试吧,后续可以再拿bagging再优化。
#coding:utf-8#超参数优化方法# from sklearn import svm, datasets# from sklearn.model_selection import GridSearchCV# iris = datasets.load_iris()# svr = svm.SVC()#分类器# paramters = {'kernel': ('rbf', 'linear'), 'C': (1, 5, 10)}#核函数,取值,穷举# clf = GridSearchCV(svr, paramters)# clf.fit(iris.data, iris.target)# print (clf.best_estimator_)#最优参数组合
2.随机采样(RandomizedSearchCV)
1、数据规模大,精确的结果难以在一定时间计算出。
2、结果的些许的不精确能够被接受。
3、求取的结果是最优化(optimization)问题,有一个成本计算模型。
import numpy as npfrom time import timefrom scipy.stats import randint as sp_randintfrom sklearn.model_selection import RandomizedSearchCV, GridSearchCV#导入from sklearn.datasets import load_digitsfrom sklearn.ensemble import RandomForestClassifier#用于报告超参数搜索的最好结果的函数def report(results, n_top=3):#从每次交叉验证中的历史信息中找到最好的三个    for i in range(1, n_top + 1):        candidates = np.flatnonzero(results['rank_test_score'] == 1)#flatnozero返回不为零的索引,==1指的是值最大的那个        for candidate in candidates:            print ("model with rank : {0}".format(i))#结果中的排名            print ("mean validation score:{0:.3f} (std: {1:.3f})".format(                results['mean_test_score'][candidate],#平均                results['std_test_score'][candidate]#标准差            ))            print ("Parameters: {0}".format(results['params'][candidate]))#参数最终的优化的值            print ("")digits = load_digits()x, y = digits.data, digits.targetclf = RandomForestClassifier(n_estimators=20)#随机森林分类param_dist = {"max_depth": [3, None],              "criterion": ["gini", "entropy"],              "min_samples_split": sp_randint(2, 11),              "min_samples_leaf": sp_randint(1, 11),              "max_features": sp_randint(1, 11),              "bootstrap": [True, False],}#参数字典n_iter_search = 20#迭代次数random_search = RandomizedSearchCV(clf, param_distributions=param_dist, n_iter=n_iter_search)#随机采样搜索start = time()random_search.fit(x, y)#拟合print ("random took %.2f seconds for %d candidates parameter settings"       % (time()-start, len(random_search.cv_results_['params'])))#花费时间和多少个候选参数# print random_search.cv_results_report(random_search.cv_results_)# param_dist1 = {"max_depth": [3, None],#                "criterion": ["gini", "entropy"],#                "min_samples_split": [2, 3, 10],#                "min_samples_leaf": [1, 3, 10],#                "max_features": [1, 3, 10],#                "bootstrap": [True, False], }# grid_research = GridSearchCV(clf, param_grid=param_dist1)# grid_research.fit(x, y)# print ("Grid took %.2f seconds for %d candidates parameter settings"#        % (time()-start, len(grid_research.cv_results_['params'])))# report(grid_research.cv_results_)





原创粉丝点击