scikit-learn源码学习之cluster.mean_shift.estimate_bandwidth
来源:互联网 发布:深圳淘宝托管 编辑:程序博客网 时间:2024/06/05 22:33
继续我的源码学习之旅,这次是mean-shift聚类算法里面的estimate_bandwidth函数。
estimate_bandwidth函数用作于mean-shift算法估计带宽,如果MeanShift函数没有传入bandwidth参数,MeanShift会自动运行estimate_bandwidth,源码地址
def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0, n_jobs=1): """Estimate the bandwidth to use with the mean-shift algorithm. That this function takes time at least quadratic in n_samples. For large datasets, it's wise to set that parameter to a small value. Parameters ---------- X : array-like, shape=[n_samples, n_features] Input points. quantile : float, default 0.3 should be between [0, 1] 0.5 means that the median of all pairwise distances is used. n_samples : int, optional The number of samples to use. If not given, all samples are used. random_state : int or RandomState Pseudo-random number generator state used for random sampling. n_jobs : int, optional (default = 1) The number of parallel jobs to run for neighbors search. If ``-1``, then the number of jobs is set to the number of CPU cores. Returns ------- bandwidth : float The bandwidth parameter. """ #根据random_state生成伪随机数生成器 random_state = check_random_state(random_state) if n_samples is not None: #permutation将序列打乱 并取n_samples个数的样本 idx = random_state.permutation(X.shape[0])[:n_samples] X = X[idx] #非监督方式进行近邻搜索 #quantile的值表示进行近邻搜索时候的近邻占样本的比例 nbrs = NearestNeighbors(n_neighbors=int(X.shape[0] * quantile), n_jobs=n_jobs) nbrs.fit(X) bandwidth = 0. #gen_batches(n,batch_size) 根据batch_size的大小生成0~n的切片 for batch in gen_batches(len(X), 500): #kneighbors返回batch里面每个点的n_sample个邻居的距离(不包括自己) #n_sample要是没有定义那就和NearestNeighbors里面的n_neighbors相等 #还有个返回值是下标,不过用不到就拿_忽略了 d, _ = nbrs.kneighbors(X[batch, :], return_distance=True) #将每个点的最近的n_neighbors个邻居中最远的距离加起来 bandwidth += np.max(d, axis=1).sum() #本质上就是求平均最远k近邻距离 return bandwidth / X.shape[0]
中文注释都是个人见解,如果有写的不到位的地方,欢迎大家评论区拍砖
1 0
- scikit-learn源码学习之cluster.mean_shift.estimate_bandwidth
- scikit-learn源码学习之cluster.MeanShift
- Scikit-learn源码学习之cluster.SpectralClustering
- scikit-learn源码学习之datasets.samples_generator.make_blobs
- 【Python学习】Scikit-learn之SVM
- scikit-learn学习之决策树算法
- scikit-learn学习之回归分析
- scikit-learn学习之贝叶斯分类算法
- scikit-learn学习之神经网络算法
- scikit-learn学习之SVM算法
- Scikit-learn机器学习实战之Kmeans
- scikit-learn学习之SVM算法
- scikit-learn学习之SVM算法
- 转载:scikit-learn学习之SVM算法
- scikit-learn学习之SVM算法
- 机器学习之scikit-learn初识
- scikit-learn学习之SVM算法
- scikit-learn学习之贝叶斯分类算法
- 求Sn=a+aa+aaa+aaaa+aaaaa的前5项之和
- svn老鸟使用git后对比
- Linux编程进程管理
- HTML 中 onclick 触发函数 xxx(param) 要传递对象参数的解决方法
- 关于面试
- scikit-learn源码学习之cluster.mean_shift.estimate_bandwidth
- httpclient 4.5.2 学习随笔(3)
- 利用conda在Hadoop-stream中使用定制python解释器
- Iptables实例练习
- sql group by 字段合并
- mongodb查询之find命令
- machine learning pre-learning (AI dairy 1)
- Html 复习3
- Mac版 office 2016 破解工具