k means算法入门

来源:互联网 发布:淘宝助理大头笔不显示 编辑:程序博客网 时间:2024/04/30 06:43

漫谈 Clustering (1): k-means

该文转自:http://blog.pluskid.org/?p=17

cluster_logo本文是“漫谈 Clustering 系列”中的第 1 篇,参见本系列的其他文章。

好久没有写 blog 了,一来是 blog 下线一段时间,而租 DreamHost 的事情又一直没弄好;二来是没有太多时间,天天都跑去实验室。现在主要折腾 Machine Learning 相关的东西,因为很多东西都不懂,所以平时也找一些资料来看。按照我以前的更新速度的话,这么长时间不写 blog 肯定是要被闷坏的,所以我也觉得还是不定期地整理一下自己了解到的东西,放在 blog 上,一来梳理总是有助于加深理解的,二来也算共享一下知识了。那么,还是从 clustering 说起吧。

Clustering 中文翻译作“聚类”,简单地说就是把相似的东西分到一组,同 Classification (分类)不同,对于一个 classifier ,通常需要你告诉它“这个东西被分为某某类”这样一些例子,理想情况下,一个 classifier 会从它得到的训练集中进行“学习”,从而具备对未知数据进行分类的能力,这种提供训练数据的过程通常叫做supervised learning (监督学习),而在聚类的时候,我们并不关心某一类是什么,我们需要实现的目标只是把相似的东西聚到一起,因此,一个聚类算法通常只需要知道如何计算相似 度就可以开始工作了,因此 clustering 通常并不需要使用训练数据进行学习,这在 Machine Learning 中被称作unsupervised learning (无监督学习)。

举一个简单的例子:现在有一群小学生,你要把他们分成几组,让组内的成员之间尽量相似一些,而组之间则差别大一些。最后分出怎样的结果,就取决于你对于“相似”的定义了,比如,你决定男生和男生是相似的,女生和女生也是相似的,而男生和女生之间则差别很大”,这样,你实际上是用一个可能取两个值“男”和“女”的离散变量来代表了原来的一个小学生,我们通常把这样的变量叫做“特征”。实际上,在这种情况下,所有的小学生都被映射到了两个点的其中一个上,已经很自然地形成了两个组,不需要专门再做聚类了。另一种可能是使用“身高”这个特征。我在读小学候,每周五在操场开会训话的时候会按照大家住的地方的地域和距离远近来列队,这样结束之后就可以结队回家了。除了让事物映射到一个单独的特征之外,一种常见的做法是同时提取 N 种特征,将它们放在一起组成一个 N 维向量,从而得到一个从原始数据集合到 N 维向量空间的映射——你总是需要显式地或者隐式地完成这样一个过程,因为许多机器学习的算法都需要工作在一个向量空间中。

那么让我们再回到 clustering 的问题上,暂且抛开原始数据是什么形式,假设我们已经将其映射到了一个欧几里德空间上,为了方便展示,就使用二维空间吧,如下图所示:

cluster

从数据点的大致形状可以看出它们大致聚为三个 cluster ,其中两个紧凑一些,剩下那个松散一些。我们的目的是为这些数据分组,以便能区分出属于不同的簇的数据,如果按照分组给它们标上不同的颜色,就是这个样子:

cluster

那么计算机要如何来完成这个任务呢?当然,计算机还没有高级到能够“通过形状大致看出来”,不过,对于这样的 N 维欧氏空间中的点进行聚类,有一个非常简单的经典算法,也就是本文标题中提到的 k-means 。在介绍 k-means 的具体步骤之前,让我们先来看看它对于需要进行聚类的数据的一个基本假设吧:对于每一个 cluster ,我们可以选出一个中心点 (center) ,使得该 cluster 中的所有的点到该中心点的距离小于到其他 cluster 的中心的距离。虽然实际情况中得到的数据并不能保证总是满足这样的约束,但这通常已经是我们所能达到的最好的结果,而那些误差通常是固有存在的或者问题本身的不可分性造成的。例如下图所示的两个高斯分布,从两个分布中随机地抽取一些数据点出来,混杂到一起,现在要让你将这些混杂在一起的数据点按照它们被生成的那个分布分开来:

gaussian

由于这两个分布本身有很大一部分重叠在一起了,例如,对于数据点 2.5 来说,它由两个分布产生的概率都是相等的,你所做的只能是一个猜测;稍微好一点的情况是 2 ,通常我们会将它归类为左边的那个分布,因为概率大一些,然而此时它由右边的分布生成的概率仍然是比较大的,我们仍然有不小的几率会猜错。而整个阴影部分是我们所能达到的最小的猜错的概率,这来自于问题本身的不可分性,无法避免。因此,我们将 k-means 所依赖的这个假设看作是合理的。

基于这样一个假设,我们再来导出 k-means 所要优化的目标函数:设我们一共有 N 个数据点需要分为 K 个 cluster ,k-means 要做的就是最小化

\displaystyle J = \sum_{n=1}^N\sum_{k=1}^K r_{nk} \|x_n-\mu_k\|^2

这个函数,其中 r_{nk} 在数据点 n 被归类到 cluster k 的时候为 1 ,否则为 0 。直接寻找r_{nk}\mu_k 来最小化J 并不容易,不过我们可以采取迭代的办法:先固定\mu_k ,选择最优的r_{nk} ,很容易看出,只要将数据点归类到离他最近的那个中心就能保证J 最小。下一步则固定r_{nk},再求最优的\mu_k。将J\mu_k 求导并令导数等于零,很容易得到J 最小的时候\mu_k 应该满足:

\displaystyle \mu_k=\frac{\sum_n r_{nk}x_n}{\sum_n r_{nk}}

亦即 \mu_k 的值应当是所有 cluster k 中的数据点的平均值。由于每一次迭代都是取到J 的最小值,因此J 只会不断地减小(或者不变),而不会增加,这保证了 k-means 最终会到达一个极小值。虽然 k-means 并不能保证总是能得到全局最优解,但是对于这样的问题,像 k-means 这种复杂度的算法,这样的结果已经是很不错的了。

下面我们来总结一下 k-means 算法的具体步骤:

  1. 选定 K 个中心 \mu_k 的初值。这个过程通常是针对具体的问题有一些启发式的选取方法,或者大多数情况下采用随机选取的办法。因为前面说过 k-means 并不能保证全局最优,而是否能收敛到全局最优解其实和初值的选取有很大的关系,所以有时候我们会多次选取初值跑 k-means ,并取其中最好的一次结果。
  2. 将每个数据点归类到离它最近的那个中心点所代表的 cluster 中。
  3. 用公式 \mu_k = \frac{1}{N_k}\sum_{j\in\text{cluster}_k}x_j 计算出每个 cluster 的新的中心点。
  4. 重复第二步,一直到迭代了最大的步数或者前后的 J 的值相差小于一个阈值为止。

按照这个步骤写一个 k-means 实现其实相当容易了,在 SciPy 或者 Matlab 中都已经包含了内置的 k-means 实现,不过为了看看 k-means 每次迭代的具体效果,我们不妨自己来实现一下,代码如下(需要安装SciPy 和matplotlib) :

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
[python] view plaincopyprint?
  1. <span style="color: rgb(128, 128, 128); font-style: italic;">#!/usr/bin/python</span> 
  2.   
  3. <span style="color: rgb(255, 119, 0); font-weight: bold;">from</span> <span style="color: rgb(220, 20, 60);">__future__</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> with_statement 
  4. <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> <span style="color: rgb(220, 20, 60);">cPickle</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">as</span> <span style="color: rgb(220, 20, 60);">pickle</span> 
  5. <span style="color: rgb(255, 119, 0); font-weight: bold;">from</span> matplotlib <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> pyplot 
  6. <span style="color: rgb(255, 119, 0); font-weight: bold;">from</span> numpy <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> zeros, <span style="color: rgb(220, 20, 60);">array</span>, tile 
  7. <span style="color: rgb(255, 119, 0); font-weight: bold;">from</span> scipy.<span style="color: black;">linalg</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> norm 
  8. <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> numpy.<span style="color: black;">matlib</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">as</span> ml 
  9. <span style="color: rgb(255, 119, 0); font-weight: bold;">import</span> <span style="color: rgb(220, 20, 60);">random</span> 
  10.   
  11. <span style="color: rgb(255, 119, 0); font-weight: bold;">def</span> kmeans<span style="color: black;">(</span>X, k, observer=<span style="color: rgb(0, 128, 0);">None</span>, threshold=1e-15, maxiter=<span style="color: rgb(255, 69, 0);">300</span><span style="color: black;">)</span>: 
  12.     N = <span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>X<span style="color: black;">)</span> 
  13.     labels = zeros<span style="color: black;">(</span>N, dtype=<span style="color: rgb(0, 128, 0);">int</span><span style="color: black;">)</span> 
  14.     centers = <span style="color: rgb(220, 20, 60);">array</span><span style="color: black;">(</span><span style="color: rgb(220, 20, 60);">random</span>.<span style="color: black;">sample</span><span style="color: black;">(</span>X, k<span style="color: black;">)</span><span style="color: black;">)</span> 
  15.     <span style="color: rgb(0, 128, 0);">iter</span> = <span style="color: rgb(255, 69, 0);">0</span> 
  16.   
  17.     <span style="color: rgb(255, 119, 0); font-weight: bold;">def</span> calc_J<span style="color: black;">(</span><span style="color: black;">)</span>: 
  18.         <span style="color: rgb(0, 128, 0);">sum</span> = <span style="color: rgb(255, 69, 0);">0</span> 
  19.         <span style="color: rgb(255, 119, 0); font-weight: bold;">for</span> i <span style="color: rgb(255, 119, 0); font-weight: bold;">in</span> <span style="color: rgb(0, 128, 0);">xrange</span><span style="color: black;">(</span>N<span style="color: black;">)</span>: 
  20.             <span style="color: rgb(0, 128, 0);">sum</span> += norm<span style="color: black;">(</span>X<span style="color: black;">[</span>i<span style="color: black;">]</span>-centers<span style="color: black;">[</span>labels<span style="color: black;">[</span>i<span style="color: black;">]</span><span style="color: black;">]</span><span style="color: black;">)</span> 
  21.         <span style="color: rgb(255, 119, 0); font-weight: bold;">return</span> <span style="color: rgb(0, 128, 0);">sum</span> 
  22.   
  23.     <span style="color: rgb(255, 119, 0); font-weight: bold;">def</span> distmat<span style="color: black;">(</span>X, Y<span style="color: black;">)</span>: 
  24.         n = <span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>X<span style="color: black;">)</span> 
  25.         m = <span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>Y<span style="color: black;">)</span> 
  26.         xx = ml.<span style="color: rgb(0, 128, 0);">sum</span><span style="color: black;">(</span>X<span style="color: rgb(102, 204, 102);">*</span>X, axis=<span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">)</span> 
  27.         yy = ml.<span style="color: rgb(0, 128, 0);">sum</span><span style="color: black;">(</span>Y<span style="color: rgb(102, 204, 102);">*</span>Y, axis=<span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">)</span> 
  28.         xy = ml.<span style="color: black;">dot</span><span style="color: black;">(</span>X, Y.<span style="color: black;">T</span><span style="color: black;">)</span> 
  29.   
  30.         <span style="color: rgb(255, 119, 0); font-weight: bold;">return</span> tile<span style="color: black;">(</span>xx, <span style="color: black;">(</span>m, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">)</span><span style="color: black;">)</span>.<span style="color: black;">T</span>+tile<span style="color: black;">(</span>yy, <span style="color: black;">(</span>n, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">)</span><span style="color: black;">)</span> - <span style="color: rgb(255, 69, 0);">2</span><span style="color: rgb(102, 204, 102);">*</span>xy 
  31.   
  32.     Jprev = calc_J<span style="color: black;">(</span><span style="color: black;">)</span> 
  33.     <span style="color: rgb(255, 119, 0); font-weight: bold;">while</span> <span style="color: rgb(0, 128, 0);">True</span>: 
  34.         <span style="color: rgb(128, 128, 128); font-style: italic;"># notify the observer</span> 
  35.         <span style="color: rgb(255, 119, 0); font-weight: bold;">if</span> observer <span style="color: rgb(255, 119, 0); font-weight: bold;">is</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">not</span> <span style="color: rgb(0, 128, 0);">None</span>: 
  36.             observer<span style="color: black;">(</span><span style="color: rgb(0, 128, 0);">iter</span>, labels, centers<span style="color: black;">)</span> 
  37.   
  38.         <span style="color: rgb(128, 128, 128); font-style: italic;"># calculate distance from x to each center</span> 
  39.         <span style="color: rgb(128, 128, 128); font-style: italic;"># distance_matrix is only available in scipy newer than 0.7</span> 
  40.         <span style="color: rgb(128, 128, 128); font-style: italic;"># dist = distance_matrix(X, centers)</span> 
  41.         dist = distmat<span style="color: black;">(</span>X, centers<span style="color: black;">)</span> 
  42.         <span style="color: rgb(128, 128, 128); font-style: italic;"># assign x to nearst center</span> 
  43.         labels = dist.<span style="color: black;">argmin</span><span style="color: black;">(</span>axis=<span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">)</span> 
  44.         <span style="color: rgb(128, 128, 128); font-style: italic;"># re-calculate each center</span> 
  45.         <span style="color: rgb(255, 119, 0); font-weight: bold;">for</span> j <span style="color: rgb(255, 119, 0); font-weight: bold;">in</span> <span style="color: rgb(0, 128, 0);">range</span><span style="color: black;">(</span>k<span style="color: black;">)</span>: 
  46.             idx_j = <span style="color: black;">(</span>labels == j<span style="color: black;">)</span>.<span style="color: black;">nonzero</span><span style="color: black;">(</span><span style="color: black;">)</span> 
  47.             centers<span style="color: black;">[</span>j<span style="color: black;">]</span> = X<span style="color: black;">[</span>idx_j<span style="color: black;">]</span>.<span style="color: black;">mean</span><span style="color: black;">(</span>axis=<span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">)</span> 
  48.   
  49.         J = calc_J<span style="color: black;">(</span><span style="color: black;">)</span> 
  50.         <span style="color: rgb(0, 128, 0);">iter</span> += <span style="color: rgb(255, 69, 0);">1</span> 
  51.   
  52.         <span style="color: rgb(255, 119, 0); font-weight: bold;">if</span> Jprev-J <span style="color: rgb(102, 204, 102);"><</span> threshold: 
  53.             <span style="color: rgb(255, 119, 0); font-weight: bold;">break</span> 
  54.         Jprev = J 
  55.         <span style="color: rgb(255, 119, 0); font-weight: bold;">if</span> <span style="color: rgb(0, 128, 0);">iter</span> <span style="color: rgb(102, 204, 102);">></span>= maxiter: 
  56.             <span style="color: rgb(255, 119, 0); font-weight: bold;">break</span> 
  57.   
  58.     <span style="color: rgb(128, 128, 128); font-style: italic;"># final notification</span> 
  59.     <span style="color: rgb(255, 119, 0); font-weight: bold;">if</span> observer <span style="color: rgb(255, 119, 0); font-weight: bold;">is</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">not</span> <span style="color: rgb(0, 128, 0);">None</span>: 
  60.         observer<span style="color: black;">(</span><span style="color: rgb(0, 128, 0);">iter</span>, labels, centers<span style="color: black;">)</span> 
  61.   
  62. <span style="color: rgb(255, 119, 0); font-weight: bold;">if</span> __name__ == <span style="color: rgb(72, 61, 139);">'__main__'</span>: 
  63.     <span style="color: rgb(128, 128, 128); font-style: italic;"># load previously generated points</span> 
  64.     <span style="color: rgb(255, 119, 0); font-weight: bold;">with</span> <span style="color: rgb(0, 128, 0);">open</span><span style="color: black;">(</span><span style="color: rgb(72, 61, 139);">'cluster.pkl'</span><span style="color: black;">)</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">as</span> inf: 
  65.         samples = <span style="color: rgb(220, 20, 60);">pickle</span>.<span style="color: black;">load</span><span style="color: black;">(</span>inf<span style="color: black;">)</span> 
  66.     N = <span style="color: rgb(255, 69, 0);">0</span> 
  67.     <span style="color: rgb(255, 119, 0); font-weight: bold;">for</span> smp <span style="color: rgb(255, 119, 0); font-weight: bold;">in</span> samples: 
  68.         N += <span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>smp<span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span><span style="color: black;">)</span> 
  69.     X = zeros<span style="color: black;">(</span><span style="color: black;">(</span>N, <span style="color: rgb(255, 69, 0);">2</span><span style="color: black;">)</span><span style="color: black;">)</span> 
  70.     idxfrm = <span style="color: rgb(255, 69, 0);">0</span> 
  71.     <span style="color: rgb(255, 119, 0); font-weight: bold;">for</span> i <span style="color: rgb(255, 119, 0); font-weight: bold;">in</span> <span style="color: rgb(0, 128, 0);">range</span><span style="color: black;">(</span><span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>samples<span style="color: black;">)</span><span style="color: black;">)</span>: 
  72.         idxto = idxfrm + <span style="color: rgb(0, 128, 0);">len</span><span style="color: black;">(</span>samples<span style="color: black;">[</span>i<span style="color: black;">]</span><span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span><span style="color: black;">)</span> 
  73.         X<span style="color: black;">[</span>idxfrm:idxto, <span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span> = samples<span style="color: black;">[</span>i<span style="color: black;">]</span><span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span> 
  74.         X<span style="color: black;">[</span>idxfrm:idxto, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">]</span> = samples<span style="color: black;">[</span>i<span style="color: black;">]</span><span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">]</span> 
  75.         idxfrm = idxto 
  76.   
  77.     <span style="color: rgb(255, 119, 0); font-weight: bold;">def</span> observer<span style="color: black;">(</span><span style="color: rgb(0, 128, 0);">iter</span>, labels, centers<span style="color: black;">)</span>: 
  78.         <span style="color: rgb(255, 119, 0); font-weight: bold;">print</span> <span style="color: rgb(72, 61, 139);">"iter %d."</span> <span style="color: rgb(102, 204, 102);">%</span> <span style="color: rgb(0, 128, 0);">iter</span> 
  79.         colors = <span style="color: rgb(220, 20, 60);">array</span><span style="color: black;">(</span><span style="color: black;">[</span><span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">1</span>, <span style="color: rgb(255, 69, 0);">0</span>, <span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span>, <span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">0</span>, <span style="color: rgb(255, 69, 0);">1</span>, <span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span>, <span style="color: black;">[</span><span style="color: rgb(255, 69, 0);">0</span>, <span style="color: rgb(255, 69, 0);">0</span>, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">]</span><span style="color: black;">]</span><span style="color: black;">)</span> 
  80.         pyplot.<span style="color: black;">plot</span><span style="color: black;">(</span>hold=<span style="color: rgb(0, 128, 0);">False</span><span style="color: black;">)</span>  <span style="color: rgb(128, 128, 128); font-style: italic;"># clear previous plot</span> 
  81.         pyplot.<span style="color: black;">hold</span><span style="color: black;">(</span><span style="color: rgb(0, 128, 0);">True</span><span style="color: black;">)</span> 
  82.   
  83.         <span style="color: rgb(128, 128, 128); font-style: italic;"># draw points</span> 
  84.         data_colors=<span style="color: black;">[</span>colors<span style="color: black;">[</span>lbl<span style="color: black;">]</span> <span style="color: rgb(255, 119, 0); font-weight: bold;">for</span> lbl <span style="color: rgb(255, 119, 0); font-weight: bold;">in</span> labels<span style="color: black;">]</span> 
  85.         pyplot.<span style="color: black;">scatter</span><span style="color: black;">(</span>X<span style="color: black;">[</span>:, <span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span>, X<span style="color: black;">[</span>:, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">]</span>, c=data_colors, alpha=<span style="color: rgb(255, 69, 0);">0.5</span><span style="color: black;">)</span> 
  86.         <span style="color: rgb(128, 128, 128); font-style: italic;"># draw centers</span> 
  87.         pyplot.<span style="color: black;">scatter</span><span style="color: black;">(</span>centers<span style="color: black;">[</span>:, <span style="color: rgb(255, 69, 0);">0</span><span style="color: black;">]</span>, centers<span style="color: black;">[</span>:, <span style="color: rgb(255, 69, 0);">1</span><span style="color: black;">]</span>, s=<span style="color: rgb(255, 69, 0);">200</span>, c=colors<span style="color: black;">)</span> 
  88.   
  89.         pyplot.<span style="color: black;">savefig</span><span style="color: black;">(</span><span style="color: rgb(72, 61, 139);">'kmeans/iter_%02d.png'</span> <span style="color: rgb(102, 204, 102);">%</span> <span style="color: rgb(0, 128, 0);">iter</span>, format=<span style="color: rgb(72, 61, 139);">'png'</span><span style="color: black;">)</span> 
  90.   
  91.     kmeans<span style="color: black;">(</span>X, <span style="color: rgb(255, 69, 0);">3</span>, observer=observer<span style="color: black;">)</span> 

代码有些长,不过因为用 Python 来做这个事情确实不如 Matlab 方便,实际的 k-means 的代码只是 41 到 47 行。首先 3 个中心点被随机初始化,所有的数据点都还没有进行聚类,默认全部都标记为红色,如下图所示:

iter_00

然后进入第一次迭代:按照初始的中心点位置为每个数据点着上颜色,这是代码中第 41 到 43 行所做的工作,然后 45 到 47 行重新计算 3 个中心点,结果如下图所示:

iter_01

可以看到,由于初始的中心点是随机选的,这样得出来的结果并不是很好,接下来是下一次迭代的结果:

iter_02

可以看到大致形状已经出来了。再经过两次迭代之后,基本上就收敛了,最终结果如下:

iter_04

不过正如前面所说的那样 k-means 也并不是万能的,虽然许多时候都能收敛到一个比较好的结果,但是也有运气不好的时候会收敛到一个让人不满意的局部最优解,例如选用下面这几个初始中心点:

iter_00_bad

最终会收敛到这样的结果:

iter_03_bad

不得不承认这并不是很好的结果。不过其实大多数情况下 k-means 给出的结果都还是很令人满意的,算是一种简单高效应用广泛的 clustering 方法。

原创粉丝点击