scikit-learn source code study: cluster.mean_shift.estimate_bandwidth

Source: Internet  Editor: 程序博客网  Time: 2024/06/05 22:33

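Continuing my journey through the source code: this time it is the estimate_bandwidth function from the mean-shift clustering module.
estimate_bandwidth estimates the bandwidth for the mean-shift algorithm. If MeanShift is created without a bandwidth argument, it calls estimate_bandwidth automatically when fitting. Source code: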
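To see that automatic fallback in action, the sketch below fits MeanShift twice: once with bandwidth left at its default (None), and once with the value returned by estimate_bandwidth called with its default arguments. The two-blob synthetic data is made up for illustration. The expectation that both fits agree is an assumption based on my reading of the source, namely that MeanShift calls estimate_bandwidth(X) with the defaults quantile=0.3 and n_samples=None.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

# Two well-separated synthetic blobs (illustrative data, not from the post)
rng = np.random.RandomState(42)
X = np.vstack([
    rng.normal(loc=0.0, scale=0.5, size=(50, 2)),
    rng.normal(loc=5.0, scale=0.5, size=(50, 2)),
])

# 1) bandwidth=None (the default): MeanShift estimates the bandwidth itself.
#    Assumption: internally this is estimate_bandwidth(X) with default arguments.
auto = MeanShift().fit(X)

# 2) Estimate the bandwidth explicitly with the same defaults and pass it in
bw = estimate_bandwidth(X)  # quantile=0.3, n_samples=None, random_state=0
manual = MeanShift(bandwidth=bw).fit(X)

print("estimated bandwidth:", bw)
print("clusters found:", len(manual.cluster_centers_))
```

Passing the bandwidth explicitly is also how you would tune it by hand, for example by trying several quantile values and keeping the one whose clustering looks reasonable.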

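```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state, gen_batches


def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0,
                       n_jobs=1):
    """Estimate the bandwidth to use with the mean-shift algorithm.

    Note that this function takes time at least quadratic in n_samples. For
    large datasets, it's wise to set that parameter to a small value.

    Parameters
    ----------
    X : array-like, shape=[n_samples, n_features]
        Input points.

    quantile : float, default 0.3
        should be between [0, 1]
        0.5 means that the median of all pairwise distances is used.

    n_samples : int, optional
        The number of samples to use. If not given, all samples are used.

    random_state : int or RandomState
        Pseudo-random number generator state used for random sampling.

    n_jobs : int, optional (default = 1)
        The number of parallel jobs to run for neighbors search.
        If ``-1``, then the number of jobs is set to the number of CPU cores.

    Returns
    -------
    bandwidth : float
        The bandwidth parameter.
    """
    # Turn random_state into a np.random.RandomState instance
    random_state = check_random_state(random_state)
    if n_samples is not None:
        # permutation shuffles the row indices; keep the first n_samples rows
        idx = random_state.permutation(X.shape[0])[:n_samples]
        X = X[idx]
    # Unsupervised nearest-neighbor search.
    # quantile is the fraction of the (sub)sampled points used as neighbors.
    # int(...) can be 0 for tiny inputs; later scikit-learn versions clamp it
    # to 1, and the same guard is applied here.
    n_neighbors = max(1, int(X.shape[0] * quantile))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs)
    nbrs.fit(X)

    bandwidth = 0.
    # gen_batches(n, batch_size) yields slices covering range(n) in chunks of
    # at most batch_size (500 here), which keeps memory use bounded
    for batch in gen_batches(len(X), 500):
        # kneighbors returns, for each point in the batch, the distances to its
        # n_neighbors nearest neighbors. n_neighbors is not passed here, so it
        # defaults to the value given to the constructor. Because the query
        # points are passed explicitly, each point is counted as its own
        # nearest neighbor at distance 0. The indices (second return value)
        # are not needed, so they go to _.
        d, _ = nbrs.kneighbors(X[batch, :], return_distance=True)
        # For each point, add the distance to the farthest of its n_neighbors
        # nearest neighbors, i.e. its n_neighbors-th nearest distance
        bandwidth += np.max(d, axis=1).sum()

    # In essence: the mean distance to the k-th nearest neighbor
    return bandwidth / X.shape[0]
```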
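The loop boils down to: for every point, take the distance to its k-th nearest neighbor, where k = int(n * quantile) and the point itself counts as a neighbor at distance 0, then average over all points. Below is a minimal brute-force NumPy sketch of that computation, assuming each batch-to-all distance matrix fits in memory. bandwidth_numpy is a hypothetical name, not a scikit-learn API. It clamps k to at least 1 so tiny inputs do not give k = 0, and for an integer random_state it subsamples with np.random.RandomState(random_state), which is what check_random_state returns for an int, so it should select the same subset.

```python
import numpy as np


def bandwidth_numpy(X, quantile=0.3, n_samples=None, random_state=0,
                    batch_size=500):
    """Hypothetical NumPy sketch of the estimate_bandwidth computation."""
    X = np.asarray(X, dtype=float)
    if n_samples is not None:
        # Integer seed -> RandomState, matching check_random_state for an int
        idx = np.random.RandomState(random_state).permutation(X.shape[0])
        X = X[idx[:n_samples]]
    n = X.shape[0]
    # Number of neighbors per point; clamped to at least 1
    k = max(1, int(n * quantile))
    total = 0.0
    for start in range(0, n, batch_size):
        batch = X[start:start + batch_size]
        # Euclidean distances from each batch point to every point (self included)
        diff = batch[:, None, :] - X[None, :, :]
        d = np.sqrt((diff ** 2).sum(axis=-1))
        # k-th smallest distance per row == max over the k nearest neighbors;
        # the point's own distance 0 occupies one of those k slots
        kth = np.partition(d, k - 1, axis=1)[:, k - 1]
        total += kth.sum()
    return total / n


X_demo = np.array([[0.0], [1.0], [3.0], [6.0]])
print(bandwidth_numpy(X_demo, quantile=0.5))  # -> 1.75
```

In the demo, k = int(4 * 0.5) = 2, so each point contributes the distance to its nearest other point: (1 + 1 + 2 + 3) / 4 = 1.75. With quantile=0.25 the same data gives k = 1 and a bandwidth of 0.0, because every point's single nearest neighbor is itself; this is the self-inclusion effect described in the code comments.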

The comments in the code are my own interpretation. If anything is inaccurate, feel free to point it out in the comments section.
