Machine Learning - Cross-Validation

Source: Internet  Published by: 安卓网络机顶盒论坛  Editor: 程序博客网  Time: 2024/04/29 12:13

Concept

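Cross-validation is a widely used method for estimating how well a model generalizes. It first partitions the dataset $D$ into $k$ mutually exclusive subsets of similar size, i.e. $D = D_1 \cup D_2 \cup \dots \cup D_k$ with $D_i \cap D_j = \varnothing$ for $i \neq j$. Each subset $D_i$ should preserve the data distribution as far as possible, which means it is obtained from $D$ by stratified sampling. In each round, the union of $k-1$ subsets serves as the training set and the remaining subset serves as the test set. This yields $k$ training/test splits, so the model is trained and tested $k$ times, and the final result is the mean of the $k$ test results. Cross-validation is therefore usually called "k-fold cross-validation"; $k$ is commonly set to 10.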
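The procedure described above can be sketched in plain Python. This is a minimal illustration and not the scikit-learn implementation: the function names `stratified_k_fold` and `cross_validate`, the toy labels, and the majority-class stand-in "model" are all assumptions made for the example.

```python
import random
from collections import Counter, defaultdict


def stratified_k_fold(labels, k, seed=0):
    """Split sample indices into k mutually exclusive folds with similar class ratios."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(k)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Deal each class round-robin so every fold receives a similar share of it.
        for pos, idx in enumerate(indices):
            folds[pos % k].append(idx)
    return folds


def cross_validate(labels, k, evaluate, seed=0):
    """Run k rounds: test on one fold, train on the union of the other k-1 folds."""
    folds = stratified_k_fold(labels, k, seed)
    scores = []
    for i, test_idx in enumerate(folds):
        train_idx = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        scores.append(evaluate(train_idx, test_idx))
    # The reported result is the mean of the k test scores.
    return sum(scores) / k, scores


# Hypothetical toy data: 60 samples of class 0 and 40 samples of class 1.
labels = [0] * 60 + [1] * 40


def majority_baseline(train_idx, test_idx):
    """Stand-in 'model': predict the most common training label, return test accuracy."""
    majority = Counter(labels[i] for i in train_idx).most_common(1)[0][0]
    return sum(labels[i] == majority for i in test_idx) / len(test_idx)


mean_score, fold_scores = cross_validate(labels, k=5, evaluate=majority_baseline)
print(fold_scores, mean_score)
```

Any real model can replace `majority_baseline`, as long as it is fitted only on `train_idx` and scored only on `test_idx`.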

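  • Advantage: every sample is used for both training and testing, so k-fold CV gives a more reliable estimate than a single train/test split and helps reveal overfitting or underfitting. The resulting score is therefore more convincing.
  • Disadvantage: the choice of k matters. It affects both the quality of the estimate and the number of models that must be trained.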
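To make the role of k concrete, the short sketch below computes, for a given dataset size, how many models are fitted and what fraction of the data each one is trained on. The helper names `fold_sizes` and `describe_k` and the sample count of 1000 are assumptions chosen for illustration only.

```python
def fold_sizes(n_samples, k):
    """Sizes of k near-equal folds: the first n_samples % k folds get one extra sample."""
    base, extra = divmod(n_samples, k)
    return [base + 1 if i < extra else base for i in range(k)]


def describe_k(n_samples, k):
    """Summarize the cost and the training-set size implied by a choice of k."""
    largest = max(fold_sizes(n_samples, k))
    return {
        "k": k,
        "model_fits": k,  # one model is trained per fold
        "largest_test_fold": largest,
        "train_fraction": (n_samples - largest) / n_samples,
    }


for k in (2, 5, 10):
    print(describe_k(1000, k))
```

A larger k lets each model train on more of the data, at the price of more training runs and smaller test folds.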

Python implementation

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_score, train_test_split

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.3, random_state=42)

# Define the GBDT model. loss is left at its default, the logistic loss
# (named 'deviance' in older scikit-learn releases, 'log_loss' in newer ones).
gbdt = GradientBoostingClassifier(init=None, learning_rate=0.05,
              max_depth=5, max_features=None, max_leaf_nodes=None,
              min_samples_leaf=1, min_samples_split=2,
              min_weight_fraction_leaf=0.0, n_estimators=500,
              random_state=None, subsample=1.0, verbose=0,
              warm_start=False)

# Train the model
gbdt.fit(X_train, y_train)
importances = gbdt.feature_importances_

# Predict and evaluate AUC on the held-out test set
# (X_all is a sparse matrix here, hence toarray())
y_pred_gbdt = gbdt.predict_proba(X_test.toarray())[:, 1]
gbdt_auc = roc_auc_score(y_test, y_pred_gbdt)
print('The AUC of GBDT: %.5f' % gbdt_auc)

# Cross-validation: cross_val_score takes the estimator as its first argument
print('--------------------------------cross_validation----------------------------')
scores = cross_val_score(gbdt, X_all, y_all, cv=5, scoring='roc_auc')
print('GBDT mean AUC:')
print(scores.mean())
print('AUC of each fold:')
print(scores)

This example uses a GBDT model. The arguments passed to cross_val_score are:

  • X_all: the full feature matrix
  • y_all: the labels
  • cv: the number of folds
  • scoring: the evaluation metric. It can be a custom scorer or one of the many built-in names. For example, 'accuracy' returns the accuracy. Error metrics carry a neg_ prefix so that a higher score is always better. Built-in names include:
         ['accuracy', 'adjusted_rand_score',
          'average_precision', 'f1', 'f1_macro',
          'f1_micro', 'f1_samples', 'f1_weighted',
          'neg_log_loss', 'neg_mean_absolute_error',
          'neg_mean_squared_error', 'neg_median_absolute_error',
          'precision', 'precision_macro',
          'precision_micro', 'precision_samples',
          'precision_weighted', 'r2', 'recall',
          'recall_macro', 'recall_micro',
          'recall_samples', 'recall_weighted',
          'roc_auc']
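Since X_all and y_all are not defined in the post, the following self-contained sketch runs the same kind of cross-validation on synthetic data. It also shows the scoring argument given once as a built-in name and once as a custom scorer built with make_scorer. The names `X_demo`, `y_demo`, and `model`, the dataset size, and the reduced n_estimators are assumptions chosen so the example runs quickly; they are not the settings used above.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic binary classification data standing in for X_all / y_all.
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)

# A small GBDT so the five fits finish quickly.
model = GradientBoostingClassifier(n_estimators=50, random_state=0)

# Stratified folds keep the class ratio similar in every fold.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# scoring as a built-in name ...
auc_scores = cross_val_score(model, X_demo, y_demo, cv=cv, scoring="roc_auc")
# ... and as a custom scorer wrapping a metric function.
f1_scores = cross_val_score(model, X_demo, y_demo, cv=cv, scoring=make_scorer(f1_score))

print("AUC per fold:", auc_scores, "mean:", auc_scores.mean())
print("F1 per fold:", f1_scores, "mean:", f1_scores.mean())
```

Passing an explicit StratifiedKFold object instead of an integer makes the shuffling and the random seed visible, which helps when results need to be reproduced.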