scikit-learn交叉验证Cross Validation and Grid Search
来源:互联网 发布:巨人网络 借壳猜想 编辑:程序博客网 时间:2024/05/21 22:25
引自:http://hisarack.logdown.com/posts/304896-scikit-learn%E5%AF%A6%E4%BD%9C-cross-validation-and-grid-search
實際上在跑驗證時,我們會希望train資料集越大越好,讓model可以學習到更正 確的參數.但是train資料集越大同樣表示我們可以用來做驗證的資料(testing)越少.
Leave-One-Out提供了一個想法,假如我們有10筆資料,第一次做驗證時把第一筆留下來做驗證,剩下編號2~10的資料拿去做驗證.第二次留第二筆作驗證,依此類推,共做10次驗證再把這10次結果平均
這邊再使用文字辨識的例子並且產生LeaveOneOut的物件,產生LeaveOneOut物件必須先告訴它整個資料集的筆數
>>> from sklearn import datasets>>> digits = datasets.load_digits()>>> print(digits.data.shape)(1797, 64)>>> from sklearn import cross_validation>>> loo = cross_validation.LeaveOneOut(1797)
model選擇Linear SVC演算法試看看!跑了1797次驗證後把分數平均,猜對是1分,猜錯是0分.可得Linear SVC正確率達97%
>>> from sklearn import svm>>> svc = svm.SVC(C=1,kernel='linear')>>> print(svc)SVC(C=1, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,kernel='linear', max_iter=-1, probability=False, random_state=None,shrinking=True, tol=0.001, verbose=False)>>> import numpy as np>>> np.mean([ svc.fit(X_digits[train],Y_digits[train]).score(X_digits[test],Y_digits[test]) for train,test in loo ])0.97996661101836391
LeaveOneOut求出的正確率會與model上線運轉後的正確率較為接近,但是資料筆數若是很多則會相當耗時.退而求其次的方法是,我們把資料切成多塊 然後使用LeaveOneOut的想法,把每一塊做ㄧ次驗證,該次驗證則使用剩下的資料去train model.假若切10塊,就只要跑10次驗證.這種方法稱之為K-fold
>>> k_fold = cross_validation.KFold(n=1797,n_folds=10)>>> n k_fold ])0.96382681564245798
除了LeaveOneOut與K-fold之外,Scikit-Learn還提供Stratified K-fold,由於資料集是有標上類別(target)的,若是類別數量少,K-fold可以在切塊時可以把每個類別切成K塊,再從每個類別中取出一塊合起來當驗證用.LeaveOneLabelOut則是把資料標上Label,並把相同Label的資料集丟出去當驗證,注意這邊的Label不可以是類別(target).除了這兩種還有LeavePOut,LeavePLabelOut.以下為Stratified K-fold範例
>>> skf = cross_validation.StratifiedKFold(digits.target,10)>>> print(skf)sklearn.cross_validation.StratifiedKFold(labels=[0 1 2 ..., 8 9 8],n_folds=10, shuffle=False, random_state=None)>>> np.mean([ svc.fit(X_digits[train],Y_digits[train]).score(X_digits[test],Y_digits[test]) for train,test in skf ])0.96108002488977162
況且實際上要跑的model都不只有一個,即使使用相同的演算法,要測試的參數組合也不會只有一組.GridSearch則為一個方便的工具,幫我們從給定的參數範圍 中找出最好的model,以下範例是測試Regularization參數C從10^-6到10^-1, 注意這邊GridSearchCV的n_jobs=-1是表示可以使用所有的CPU做平行化運算
>>> from sklearn.grid_search import GridSearchCV>>> Cs = np.logspace(-6, -1, 10)>>> clf = GridSearchCV(estimator=svc,param_grid=dict(C=Cs),n_jobs=-1)>>> clf.fit(X_digits[:1000], Y_digits[:1000])GridSearchCV(cv=None, error_score='raise',estimator=SVC(C=1, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,kernel='linear', max_iter=-1, probability=False,random_state=None,shrinking=True, tol=0.001, verbose=False),fit_params={}, iid=True, loss_func=None, n_jobs=-1,param_grid={'C': array([ 1.00000e-06, 3.59381e-06, 1.29155e-05,4.64159e-05,1.66810e-04, 5.99484e-04, 2.15443e-03, 7.74264e-03,2.78256e-02, 1.00000e-01])},pre_dispatch='2*n_jobs', refit=True, score_func=None, scoring=None,verbose=0)>>> clf.best_score_0.9250000000000000>>> clf.best_estimator_SVC(C=0.0077426368268112772, cache_size=200, class_weight=None, coef0=0.0,degree=3, gamma=0.0, kernel='linear', max_iter=-1, probability=False,random_state=None, shrinking=True, tol=0.001, verbose=False)
最後最後797筆資料模擬線上運行後的正確率
>>> clf.score(X_digits[1000:], Y_digits[1000:])0.94353826850690092
- scikit-learn交叉验证Cross Validation and Grid Search
- 【scikit-learn】05:交叉验证 Cross-validation
- cross-validation 交叉验证
- 交叉验证--Cross validation
- cross-validation 交叉验证
- 交叉验证(Cross Validation)
- cross validation交叉验证
- cross validation 交叉验证
- 交叉验证(Cross-validation)
- 交叉验证(Cross-Validation)
- 交叉验证(Cross-Validation)
- 交叉验证(Cross-Validation)
- 交叉验证(Cross-validation)
- 交叉验证(Cross validation)
- 交叉验证(Cross-validation)
- 交叉验证 Cross-validation
- 【机器学习】交叉验证和K-折交叉验证cross-validation and k-fold cross-validation
- 交叉验证(cross-validation)
- sql去掉查询某个列相同的记录数
- logisitic 回归 +极大似然法 + 梯度下降法 (迭代优化)
- EventBus源码阅读(9)-SubscriberInfoIndex
- 不可错过的一些精彩的android 组件view
- WampServer下修改和重置MySQL密码
- scikit-learn交叉验证Cross Validation and Grid Search
- javascript中的可变参数
- android ImageView 显示本地图片
- 添加背景音乐
- poj2253Frogger(dijkstra求最小边权问题)
- Exception in thread "main" java.lang.Error: Unresolved compilation problem:
- socket的API函数总结
- [总结]视音频编解码技术零基础学习方法
- python 获取网络图片并下载到本地(由网络源码改编)