k-means聚类算法

来源:互联网 发布:sql重庆培训 编辑:程序博客网 时间:2024/05/17 03:08
开始之前先介绍下什么是簇识别。
簇识别:簇识别给出聚类结果的含义。假定有一些数据,现在将相似数据归到一起,簇识别会告诉我们这些簇到底都是些什么。
k-均值是发现给定数据集的K个簇的算法。簇个数是用户给定的,每一个簇通过其质心,即簇中所有点的中心来描述。

k-means聚类算法的评价:
优点:容易实现
缺点:可能收敛到局部最小值,在大规模数据集上收敛较慢。
使用数据类型:数值型数据。

工作流程:
首先,随机确定K个初始点作为质心。然后将数据集中的每个点分配到一个簇中,具体来讲,为每个点找距其最近的质心。并将其分配给该质心所对应的簇。这一步完成之后,每个簇的质心更新为该簇所有点的平均值,然后使用心的质心迭代聚类过程,直到收敛为止。
(这里的收敛指质心的不在改变,如果迭代后的质心位置与上一次相比如果超过了设定的差,那么就说明还未收敛,需要继续迭代)

伪代码:
1.创建k个点作为起始质心(经常是随机选择)
2.当任意一个点的簇分配结果发生改变时
        对数据集中的每个数据点
                对每个质心
                        计算质心与数据点之间的距离
                将数据点分配到距其最近的簇
        对每一个簇,计算簇中所有点的均值并将均值作为质心
3.使用新的质心来迭代2,直到收敛为止。
一般流程:
收集数据:使用任意方法
准备数据:需要数值型数据来计算距离,也可以将标称型数据映射为二值型数据再用于距离计算。
分析数据:使用任意方法
训练算法:不适用于无监督学习,即无监督学习没有训练过程。
测试算法:应用聚类算法、观察结果。可以使用量化的错误指标如误差平方和来评价算法的结果
使用算法:可以用于所希望的任何应用。通常情况下,簇质心可以代表整个簇的数据来做出决定。

其他问题
离群点的处理
离群点可能过度影响簇的发现,导致簇的最终发布会与我们的预想有较大出入,所以提前发现并剔除离群点是有必要的。
利用方差来剔除离群点,结果显示效果非常好。下文java实现的算法并未剔除利群点,如果有需要,可以在自己的算法实现中加入这一步。
簇分裂和簇合并(代码下面会做具体的介绍)
使用较大的K,往往会使得聚类的结果看上去更加合理,但很多情况下,我们并不想增加簇的个数。
这时可以交替采用簇分裂和簇合并。这种方式可以避开局部极小,并且能够得到具有期望个数簇的结果。

代码:
  1. import java.util.ArrayList;
  2. import java.util.EmptyStackException;
  3. import java.util.Random;
  4. import java.util.Scanner;
  5. /**
  6. * Created by wubo on 2016/10/27.
  7. */
  8. public class Kmeans {
  9. private int k;//簇的个数
  10. private int m;// 迭代次数
  11. private int dataSetLength;// 数据集元素个数,即数据集的长度
  12. private ArrayList<float[]> dataSet;// 数据集链表
  13. private ArrayList<float[]> center;// 中心链表
  14. private ArrayList<ArrayList<float[]>> cluster; // 簇
  15. private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
  16. private Random random;
  17. /**
  18. * 设置需分组的原始数据集
  19. *
  20. * @param dataSet
  21. */
  22. public void setDataSet(ArrayList<float[]> dataSet) {
  23. this.dataSet = dataSet;
  24. }
  25. /**
  26. * 获取结果分组
  27. *
  28. * @return 结果集
  29. */
  30. public ArrayList<ArrayList<float[]>> getCluster() {
  31. return cluster;
  32. }
  33. /**
  34. * 构造函数,传入需要分成的簇数量
  35. *
  36. * @param k
  37. * 簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
  38. */
  39. public Kmeans(int k) {
  40. if (k <= 0) {
  41. k = 1;
  42. }
  43. this.k = k;
  44. }
  45. /**
  46. * 初始化
  47. */
  48. private void init() {
  49. m = 0;//初始化迭代次数
  50. random = new Random();
  51. if (dataSet == null || dataSet.size() == 0) {
  52. initDataSet();
  53. }
  54. dataSetLength = dataSet.size();//数据集长度
  55. if (k > dataSetLength) {
  56. k = dataSetLength;
  57. }
  58. center = initCenters();//初始化质心
  59. System.out.print("error");
  60. cluster = initCluster();//初始化簇集合
  61. jc = new ArrayList<Float>();//误差平方和
  62. }
  63. /**
  64. * 如果调用者未初始化数据集,则采用内部测试数据集
  65. */
  66. private void initDataSet() {
  67. dataSet = new ArrayList<float[]>();
  68. // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
  69. float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
  70. { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
  71. { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
  72. for (int i = 0; i < dataSetArray.length; i++) {
  73. dataSet.add(dataSetArray[i]);
  74. System.out.print("["+dataSetArray[i][0]+","+dataSetArray[i][1]+"]");
  75. }
  76. }
  77. /**
  78. * 初始化中心数据链表,分成多少簇就有多少个中心点
  79. *
  80. * @return 中心点集
  81. */
  82. private ArrayList<float[]> initCenters() {
  83. ArrayList<float[]> center = new ArrayList<float[]>();
  84. int[] randoms = new int[k];//创建存放质心的数组
  85. boolean flag;
  86. int temp = random.nextInt(dataSetLength);//随机一个质点
  87. randoms[0] = temp;//质点加入数组中
  88. for (int i = 1; i < k; i++) {
  89. flag=false;
  90. while(!flag){
  91. temp=random.nextInt(dataSetLength);
  92. int j=0;
  93. for (j=0;j<i;j++){
  94. if (randoms[j]==temp){
  95. break;
  96. }
  97. }
  98. if(j==i){
  99. flag=true;
  100. }
  101. }
  102. randoms[i] = temp;
  103. }
  104. //测试随机数生成情况
  105. for(int i=0;i<k;i++)
  106. {
  107. System.out.println("test1:randoms["+i+"]="+randoms[i]);
  108. }
  109. for (int i = 0; i < k; i++) {
  110. center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
  111. }
  112. return center;
  113. }
  114. /**
  115. * 初始化簇集合
  116. *
  117. * @return 一个分为k个簇的空数据的簇集合
  118. */
  119. private ArrayList<ArrayList<float[]>> initCluster() {
  120. ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
  121. for (int i = 0; i < k; i++) {
  122. cluster.add(new ArrayList<float[]>());
  123. }
  124. return cluster;
  125. }
  126. /**
  127. * 计算两个点之间的距离
  128. *
  129. * @param element
  130. * 点1
  131. * @param center
  132. * 点2
  133. * @return 距离
  134. */
  135. private float distance(float[] element, float[] center) {
  136. float distance = 0.0f;
  137. float x = element[0] - center[0];
  138. float y = element[1] - center[1];
  139. float z = x * x + y * y;
  140. distance = (float) Math.sqrt(z);
  141. return distance;
  142. }
  143. /**
  144. * 获取距离集合中最小距离的位置
  145. *
  146. * @param distance
  147. * 距离数组
  148. * @return 最小距离在距离数组中的位置
  149. */
  150. private int minDistance(float[] distance) {
  151. float minDistance = distance[0];
  152. int minLocation = 0;
  153. for (int i = 1; i < distance.length; i++) {
  154. if (distance[i] < minDistance) {
  155. minDistance = distance[i];
  156. minLocation = i;
  157. } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
  158. {
  159. if (random.nextInt(10) < 5) {
  160. minLocation = i;
  161. }
  162. }
  163. }
  164. return minLocation;
  165. }
  166. /**
  167. * 核心,将当前元素放到最小距离中心相关的簇中
  168. */
  169. private void clusterSet() {
  170. float[] distance = new float[k];
  171. for (int i = 0; i < dataSetLength; i++) {
  172. for (int j = 0; j < k; j++) {
  173. distance[j] = distance(dataSet.get(i), center.get(j));//计算出当前点到每个簇质心的距离
  174. }
  175. int minLocation = minDistance(distance);//取得的最小距离的簇在数组中的编号
  176. cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中
  177. }
  178. }
  179. /**
  180. * 求两点误差平方的方法
  181. *
  182. * @param element
  183. * 点1
  184. * @param center
  185. * 点2
  186. * @return 误差平方
  187. */
  188. private float errorSquare(float[] element, float[] center) {
  189. float x = element[0] - center[0];
  190. float y = element[1] - center[1];
  191. float errSquare = x * x + y * y;
  192. return errSquare;
  193. }
  194. /**
  195. * 计算误差平方和准则函数方法,所有点到质心的误差平方和之和
  196. */
  197. private void countRule() {
  198. float jcF = 0;
  199. for (int i = 0; i < cluster.size(); i++) {
  200. for (int j = 0; j < cluster.get(i).size(); j++) {
  201. jcF += errorSquare(cluster.get(i).get(j), center.get(i));
  202. }
  203. }
  204. jc.add(jcF);
  205. }
  206. /**
  207. * 设置新的簇中心方法,平面的质心的求解算法,即求x,y的平均值
  208. */
  209. private void setNewCenter() {
  210. for (int i = 0; i < k; i++) {
  211. int n = cluster.get(i).size();
  212. if (n != 0) {
  213. float[] newCenter = { 0, 0 };
  214. for (int j = 0; j < n; j++) {
  215. newCenter[0] += cluster.get(i).get(j)[0];
  216. newCenter[1] += cluster.get(i).get(j)[1];
  217. }
  218. // 设置一个平均值
  219. newCenter[0] = newCenter[0] / n;
  220. newCenter[1] = newCenter[1] / n;
  221. center.set(i, newCenter);
  222. }
  223. }
  224. }
  225. /**
  226. * 打印数据,测试用
  227. *
  228. * @param dataArray
  229. * 数据集
  230. * @param dataArrayName
  231. * 数据集名称
  232. */
  233. public void printDataArray(ArrayList<float[]> dataArray,
  234. String dataArrayName) {
  235. for (int i = 0; i < dataArray.size(); i++) {
  236. System.out.println("print:" + dataArrayName + "[" + i + "]={"
  237. + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
  238. }
  239. System.out.println("===================================");
  240. }
  241. /**
  242. * Kmeans算法核心过程方法
  243. */
  244. private void kmeans() {
  245. init();
  246. System.out.print("finish init");
  247. //printDataArray(dataSet,"initDataSet");
  248. // printDataArray(center,"initCenter");
  249. // 循环分组,直到误差不变为止
  250. while (true) {
  251. clusterSet();
  252. countRule();//计算误差平方和
  253. // 误差不变了,分组完成
  254. if (m != 0) {//m为迭代次数
  255. if (Math.abs(jc.get(m) - jc.get(m - 1)) <1E-10) {
  256. System.out.print(jc.get(m) - jc.get(m - 1));
  257. System.out.print("迭代完成\n");
  258. break;
  259. }
  260. }
  261. printDataArray(center,"newCenter");
  262. setNewCenter();
  263. m++;
  264. cluster.clear();
  265. cluster = initCluster();
  266. }
  267. printDataArray(center,"newCenter");
  268. System.out.println("note:the times of repeat:m="+m);//输出迭代次数
  269. }
  270. /**
  271. * 执行算法
  272. */
  273. public void execute() {
  274. long startTime = System.currentTimeMillis();
  275. System.out.println("kmeans begins");
  276. kmeans();
  277. long endTime = System.currentTimeMillis();
  278. System.out.println("kmeans running time=" + (endTime - startTime)
  279. + "ms");
  280. System.out.println("kmeans ends");
  281. System.out.println();
  282. }
  283. }
  1. import sun.security.jgss.krb5.Krb5NameElement;
  2. /**
  3. * Created by wubo on 2016/10/28.
  4. */
  5. public class KmeansTest {
  6. public static void main(String[] args) {
  7. Kmeans kmeans =new Kmeans(2);
  8. kmeans.execute();
  9. }
  10. }

使用后处理来提高聚类性能:
       k-means算法中簇的数目k是一个用户预先定义的参数,那么用户如何知道K的选择是否正确?如何才能知道簇的选择比较好?
上文中未做介绍,下文来细说下。
       K-均值算法收敛但聚类效果较差的原因是,k-均值算法收敛到了局部最小值而非全局最小值。(局部最小值指结果还可以但并非最好结果,全局最小值是最好的结果)。
       一种用于度量聚类效果的指标是SSE(Sum of Squard Error,误差平方和)。SSE值越小表示数据点越接近与它们的质心,聚类效果也越好。因为对误差去了平方,因此更加重视那些远离中心的点。一种肯定可以降低SSE值的方法是增加簇的个数,但是这违背了聚类的目标。聚类的目标是在保持簇数目不变的情况下提高簇的质量。
       另一种方法就是将具有最大的SSE值的簇划分为两个簇。具体实现时可以将最大簇包含的点过滤出来,并在这些点上运行K-均值算法,其中的k设为2。同时为了保证簇总数不变,可以将某两个簇合并。
       那么如何选择要合并的两个簇,有两种可以量化的办法:合并最近的质心,或者合并两个使得SSE增幅最小的质心。第一种思路是计算所有质心之间的距离,第二种需要合并两个簇然后计算总SSE值,必需在所有可能的两个簇上重复上述处理过程,直到找出合并最佳的两个簇为止。 
        为此,在k-均值算法的基础上有实现了该技术的 二分 k-均值算法

0 0
原创粉丝点击