sklearn.utils.shuffle解析

来源:互联网 发布:php判断水仙花数 编辑:程序博客网 时间:2024/05/24 07:10

在进行机器学习时,经常需要打乱样本,这种时候Python中叒有第三方库提供了这个功能——sklearn.utils.shuffle。

Shuffle arrays or sparse matrices in a consistent way. This is a convenience alias to resample(*arrays, replace=False) to do random permutations of the collections.

函数参数

Parameters

参数 介绍 *array 带索引的序列,可以是arrays, lists, dataframes或scipy sparse matrices random_state int,随机量,就是一个random seed。如果是int,该参数作为random seed的值;如果是None,随机生成器就是一个np.random实例 n_sample int,默认为None,输出的样本数目。如果是空,则样本数目会设置为array的第一维元素数

Returns

参数 介绍 shuffled_arrays 带索引的序列,是一个view(也就是说不会改变输入array)

Examples

这里写图片描述

解释:例程中建立了3个带索引的序列:array, array和sparse matrix。然后将它们作为一个元组进行shuffle,其中random_state=0表示它们的打乱方式是方式0。这个打乱方式不理解的可以看一下np.random.seed的介绍或者是看我接下来对源码的解析。

源码

shuffle

def shuffle(*arrays, **options):    options['replace'] = False    return resample(*arrays, **options)

Are you kidding? 这是个“空壳函数”。唯一的作用就是将一个参数replace置为了False,好让shuffle过程中不影响输入array(不过要记住这个replace,这是sklearn.utils.shufflesklearn.utils.resample唯一的区别)。

那么下面来看resample函数。

resample

def resample(*arrays, **options):    '''先是类型检测部分,可以跳过'''    random_state = check_random_state(options.pop('random_state', None))  # 此处注意:返回类型变了,变成:np.random.mtrand._rand或np.random.RandomState(seed)或seed    replace = option.pop('replace', True)  # 如果没有‘replace’则返回True    max_n_samples = options.pop('n_samples', None)    if options:        raise ValueError("Unexpected kw arguments: %r" % options.keys())    if len(arrays) == 0:    return None    first = arrays[0]    n_samples = first.shape[0] if hasattr(first, 'shape') else len(first)    if max_n_samples is None:        max_n_samples = n_samples    elif (max_n_samples > n_samples) and (not replace):        raise ValueError("Cannot sample %d out of arrays with dim %d when replace is False" % (max_n_samples, n_samples))    check_consistent_length(*array)    '''开始正文'''    '''重排索引'''    if replace:        indices = random_state.randint(0, n_samples, size=(max_n_samples,))  # 创建新的随机序列索引indices    else:        indices = np.arange(n_samples)        random_state.shuffle(indices)        indices = indices[:max_n_samples]    # convert sparse matrices to CSR for row-based indexing    arrays = [a.tocsr() if issparse(a) else a for a in arrays]    '''根据indices对arrays进行采样'''    resampled_arrays = [safe_indexing(a, indices) for a in arrays]    '''分两种情况,一种是输入的*arrays参数只有一个序列,另一种是输入的*arrays参数是一元组的序列'''    if len(resampled_arrays) == 1:        # syntactic sugar for the unit argument case        return resampled_arrays[0]    else:        return resampled_arrays

解释:代码分为三部分:

  1. 类型检测
  2. 构建重排后的索引
  3. 根据索引输出序列

random_state在函数中用于产生随机索引

原创粉丝点击