sklearn.model_selection.GroupKFold

Source: Internet  Editor: 程序博客网  Time: 2024/06/09 23:58

Group K-fold cross-validation: sklearn.model_selection.GroupKFold(n_splits=3)

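Parameters:

n_splits: the number of folds. It defaults to 3 in the version used here (scikit-learn 0.22 changed the default to 5). It must be at least 2 and cannot exceed the number of distinct groups.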
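A minimal sketch of the two limits on n_splits, assuming a recent scikit-learn. A value below 2 is rejected when the splitter is constructed. A value larger than the number of distinct groups is rejected only when the split() generator is consumed. The helper `raises_value_error` is a hypothetical name used for illustration.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(24).reshape(12, 2)
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])  # 6 distinct groups


def raises_value_error(fn):
    """Hypothetical helper: return True if calling fn() raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


# n_splits must be at least 2: checked in the constructor.
too_few = raises_value_error(lambda: GroupKFold(n_splits=1))

# n_splits may not exceed the number of distinct groups: checked lazily,
# so the generator has to be consumed (here with list()) to trigger it.
too_many = raises_value_error(
    lambda: list(GroupKFold(n_splits=7).split(X, groups=groups))
)

# n_splits equal to the number of groups is accepted.
ok = len(list(GroupKFold(n_splits=6).split(X, groups=groups))) == 6

print(too_few, too_many, ok)  # True True True
```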

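Note: samples from the same group never appear in both the training set and the test set of the same fold. Each group is held out as test data in exactly one fold.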
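One way to check this guarantee is to compare, fold by fold, the set of groups on the training side with the set on the test side. The sketch below uses the data from example ③ and also counts how often each group is held out. The variable names are illustrative.

```python
import numpy as np
from collections import Counter
from sklearn.model_selection import GroupKFold

X = np.arange(24).reshape(12, 2)
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 3])

test_group_counts = Counter()
for train_index, test_index in GroupKFold(n_splits=5).split(X, y, groups):
    train_groups = set(groups[train_index].tolist())
    test_groups = set(groups[test_index].tolist())
    # No group may sit on both sides of the same fold.
    assert train_groups.isdisjoint(test_groups)
    # Train and test together cover every sample exactly once.
    assert len(train_index) + len(test_index) == len(X)
    test_group_counts.update(test_groups)

# Each distinct group should be held out in exactly one fold.
print(sorted(test_group_counts.items()))
# [(1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)]
```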

① The data divides evenly across the folds, and every group has the same number of samples

from sklearn.model_selection import GroupKFold
import numpy as np

X = np.arange(24).reshape(12, 2)                          # 12 samples, 2 features each
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])        # labels (not used to form the folds)
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])   # group label of each sample
kf = GroupKFold(n_splits=6)
for train_index, test_index in kf.split(X, y, groups):
    print('train_index:%s , test_index: %s ' % (train_index, test_index))
    print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))

Output:
train_index:[ 0  1  2  3  4  6  7  8  9 10] , test_index: [ 5 11]
train_groups:[1 2 3 4 5 1 2 3 4 5] , test_groups: [6 6]
train_index:[ 0  1  2  3  5  6  7  8  9 11] , test_index: [ 4 10]
train_groups:[1 2 3 4 6 1 2 3 4 6] , test_groups: [5 5]
train_index:[ 0  1  2  4  5  6  7  8 10 11] , test_index: [3 9]
train_groups:[1 2 3 5 6 1 2 3 5 6] , test_groups: [4 4]
train_index:[ 0  1  3  4  5  6  7  9 10 11] , test_index: [2 8]
train_groups:[1 2 4 5 6 1 2 4 5 6] , test_groups: [3 3]
train_index:[ 0  2  3  4  5  6  8  9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 6] , test_groups: [2 2]
train_index:[ 1  2  3  4  5  7  8  9 10 11] , test_index: [0 6]
train_groups:[2 3 4 5 6 2 3 4 5 6] , test_groups: [1 1]
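In example ① n_splits equals the number of distinct groups (6), so each test fold holds exactly one group. Under that condition the folds should coincide with those of LeaveOneGroupOut, possibly in a different order. The sketch below compares the two splitters as unordered collections of test sets; `fold_test_sets` is a hypothetical helper name.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut

X = np.arange(24).reshape(12, 2)
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])


def fold_test_sets(splitter):
    """Collect each fold's test indices as a frozenset, ignoring fold order."""
    return {frozenset(test.tolist()) for _, test in splitter.split(X, y, groups)}


gkf_sets = fold_test_sets(GroupKFold(n_splits=6))
logo_sets = fold_test_sets(LeaveOneGroupOut())

print(gkf_sets == logo_sets)  # True
```

GroupKFold in this run yields the groups from 6 down to 1, while LeaveOneGroupOut yields them in sorted order, which is why the comparison ignores fold order.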
② n_splits divides the sample count evenly (12 samples, 4 folds), but group sizes are unbalanced

from sklearn.model_selection import GroupKFold
import numpy as np

X = np.arange(24).reshape(12, 2)                          # 12 samples, 2 features each
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])        # labels (not used to form the folds)
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 7])   # groups 6 and 7 have one sample each
kf = GroupKFold(n_splits=4)
for train_index, test_index in kf.split(X, y, groups):
    print('train_index:%s , test_index: %s ' % (train_index, test_index))
    print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))

Output:
train_index:[ 1  2  3  5  7  8  9 11] , test_index: [ 0  4  6 10]
train_groups:[2 3 4 6 2 3 4 7] , test_groups: [1 5 1 5]
train_index:[ 0  1  2  4  5  6  7  8 10] , test_index: [ 3  9 11]
train_groups:[1 2 3 5 6 1 2 3 5] , test_groups: [4 4 7]
train_index:[ 0  1  3  4  6  7  9 10 11] , test_index: [2 5 8]
train_groups:[1 2 4 5 1 2 4 5 7] , test_groups: [3 6 3]
train_index:[ 0  2  3  4  5  6  8  9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 7] , test_groups: [2 2]
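Although 12 samples divide evenly into 4 folds, the test folds printed in example ② hold 4, 3, 3 and 2 samples. Groups 1 to 5 have 2 samples each and groups 6 and 7 have 1. A group is never split, so four folds of exactly 3 samples would need four single-sample groups. A short sketch that prints the group sizes and the test-fold sizes:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(24).reshape(12, 2)
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 7])

# Number of samples in each group.
unique_groups, group_sizes = np.unique(groups, return_counts=True)
size_by_group = dict(zip(unique_groups.tolist(), group_sizes.tolist()))
print(size_by_group)  # {1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 1, 7: 1}

# Number of samples in each test fold.
fold_sizes = [len(test) for _, test in GroupKFold(n_splits=4).split(X, y, groups)]
print(fold_sizes)
```

With the output shown above, `fold_sizes` corresponds to the lengths of the four test_index arrays.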
③ n_splits does not divide the sample count evenly (12 samples, 5 folds), and group sizes are unbalanced

from sklearn.model_selection import GroupKFold
import numpy as np

X = np.arange(24).reshape(12, 2)                          # 12 samples, 2 features each
y = np.array([1, 1, 2, 3, 1, 2, 3, 2, 2, 3, 3, 1])        # labels (not used to form the folds)
groups = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 3])   # group 3 has 3 samples, group 6 has 1
kf = GroupKFold(n_splits=5)
for train_index, test_index in kf.split(X, y, groups):
    print('train_index:%s , test_index: %s ' % (train_index, test_index))
    print('train_groups:%s , test_groups: %s ' % (groups[train_index], groups[test_index]))

Output:
train_index:[ 0  1  3  4  5  6  7  9 10] , test_index: [ 2  8 11]
train_groups:[1 2 4 5 6 1 2 4 5] , test_groups: [3 3 3]
train_index:[ 0  1  2  3  6  7  8  9 11] , test_index: [ 4  5 10]
train_groups:[1 2 3 4 1 2 3 4 3] , test_groups: [5 6 5]
train_index:[ 0  1  2  4  5  6  7  8 10 11] , test_index: [3 9]
train_groups:[1 2 3 5 6 1 2 3 5 3] , test_groups: [4 4]
train_index:[ 0  2  3  4  5  6  8  9 10 11] , test_index: [1 7]
train_groups:[1 3 4 5 6 1 3 4 5 3] , test_groups: [2 2]
train_index:[ 1  2  3  4  5  7  8  9 10 11] , test_index: [0 6]
train_groups:[2 3 4 5 6 2 3 4 5 3] , test_groups: [1 1]
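The fold assignments in all three examples are consistent with a greedy heuristic: take the groups from largest to smallest and put each one into the fold that currently holds the fewest samples. The NumPy-only function below is our own reconstruction of that idea. `greedy_group_folds` is a hypothetical name, not a scikit-learn API. Its tie-breaking (later group first among equal sizes, first fold among equally light folds) is chosen because it reproduces the printed output, and scikit-learn may behave differently in other versions.

```python
import numpy as np


def greedy_group_folds(groups, n_splits):
    """Return one array of test indices per fold.

    Hypothetical reconstruction of a largest-group-first, lightest-fold
    heuristic; not the scikit-learn source.
    """
    unique_groups, group_idx = np.unique(groups, return_inverse=True)
    if n_splits > len(unique_groups):
        raise ValueError("n_splits cannot exceed the number of distinct groups")
    sizes = np.bincount(group_idx)  # samples per group
    # Largest groups first; reversing a stable ascending sort puts the
    # later group first when two groups have the same size.
    order = np.argsort(sizes, kind="stable")[::-1]
    fold_weight = np.zeros(n_splits, dtype=int)
    group_to_fold = np.empty(len(unique_groups), dtype=int)
    for g in order:
        lightest = int(np.argmin(fold_weight))  # first fold among ties
        fold_weight[lightest] += sizes[g]
        group_to_fold[g] = lightest
    sample_fold = group_to_fold[group_idx]  # fold of every sample
    return [np.flatnonzero(sample_fold == f) for f in range(n_splits)]


# Data from example ③.
groups_3 = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 3])
for test_index in greedy_group_folds(groups_3, n_splits=5):
    print(test_index, groups_3[test_index])
```

For example ③ this yields the same test indices as the output above, starting with [ 2  8 11] for group 3, the largest group.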


