基于连通图的分裂聚类算法

来源:互联网 发布:致青春知乎 编辑:程序博客网 时间:2024/05/07 13:16

参考文献:基于连通图动态分裂的聚类算法.作者:邓健爽 郑启伦 彭宏 邓维维(华南理工大学计算机科学与工程学院,广东广州510640)

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib 

算法介绍

从文章的标题可以看出,今天我所介绍的算法又是一个聚类算法,不过他比较特殊,用到了图方面的知识,而且是一种动态的算法,与BIRCH算法一样,他也是一种层次聚类的算法,BIRCH算法是属于那种,一步步慢慢合并从而形成最终的聚类结果,而本文所描述的算法则恰巧相反,通过不断分裂直到最后不能在分裂下去为止,事实上,通过分裂实现的聚类的算法并不常见,平时说的比较多的这种算法就是chameleon算法,基于连通图的分裂聚类算法与此很类似,但又有少许的不同。首先声明这个算法的提出是出自于某篇学术论文,人家提出了这个思想,我去做了一下学习和实现,所以在这里分享一下。

算法的原理

算法的大的方向的阶段为2个阶段,第一个是根据坐标点的位置距离关系形成连通图。第二个阶段是将形成的多个连通图,进行逐一的分裂。图形化的表示过程如下,方便大家理解。



这么看来,和chameleon算法还是非常类似的。第一个步骤可以采用我的上一篇文章中用到的dbscan算法的思路,去深度优先搜索尽可能大的范围的点集,然后再用边将他们连接起来。这个如果不清楚的话,可以点击我的上一篇文章进行查阅。在这里会给定一个距离阈值l,这样就会生出基于距离l的连通图集。在上图中,就生成了2个连通图集,上面的一个和下面的一个。下面主要讲一下分裂的机理和过程,这也是整个算法的创新点和难点所在。

分裂的原理

分裂的原理采用了类似于扁担挑重物的形式,每一条边类似于一个扁担,坐标点在这里就是一个个的重物,如果扁担的2端的重物都非常重,那么扁担就容易断,于是就会分裂。举个例子如下:


但是我们要怎么去衡量一条边能不能够被分裂的标准呢,在这里定义了2个概念,承受系数t和分裂阈值landa。承受因为t就是要分裂的2部分中的较轻的一端的重量/连接2部分的边数,意思就是平均每条边所要承受的点的个数。公式如下:

t=min{W1,W2}/n,W1,W2为分割后的2部分的点的个数,n为2连接2部分的边的数量。

理解了这个,就很好分裂阈值了,分裂阈值就是当前针对全部的连通图,每条边的承受状况指数,你可以理解为就是总坐标点数/总边数。但是我们在这里采用更科学的方式进行计算,大意还是如上面描述的那样:


注意这里的x和y的关系,与上面的已经不一样了,至于这个公式为什么就不比刚刚的那个要好,就不是本文所论述的范畴了。截止到这里,我们就能得出一个比较条件了,就是当根据某条边进行分割的时候,如果此时计算出来的承受系数大于等于分裂阈值的时候,就表明此边是可以被分割掉的,也就是说,此时的连通图可以继续被拆分掉。算法的伪代码如下:

main()

{

Result r;

for-each每个连通图G

{

Graph[] graphs;

graphs = splitGraph(G)

r.add(graphs)

}

}

splitGraph(连通图G)

{

//默认不能被划分

int canDivied=0;

for(m从2到Pnum/2) //Pnum为连通图中的坐标点数

{

//将原图进行分割

Graph2 subGraph2 =G,removeM();

Graph1 subGraph1 = G;

//此函数会判断承受系数是否大于此时的分裂阈值

if(canDivide(subGraph1, subGraph2))

{

//改变标签

canDivied=1;

//继续递归的划分子图1,子图2

split(subGraph1);

split(subGraph2);

}


if(canDivided == 0)

{

//说明不能在分割了,为一个聚类,加入结果集中

addToResult()

}

}

上面的伪代码是自己想出来的,与论文原文所描述略有不同,我对其中加入了个人的思考和改进的地方,首先一点都是一样的,就是分裂一定是递归进行的,后一次的划分是建立在前一次划分的基础上进行的。以上就是第二阶段所做的事情,然后再次把目标转向问题本身,因为此问题是基于连通图的,所以在这里我用了边的数组表示,他其实是一个无向图,我还是用了id对id的形式来表示是否存在连接2点的边。下面也是算法的代码实现,也非常的重要哦(请仔细看里面的一些实现细节)。

算法的实现

首先是数据的点输入graphData.txt(格式:id  横坐标 纵坐标):

[java] view plaincopyprint?
  1. 0 1 12  
  2. 1 3 9  
  3. 2 3 12  
  4. 3 4 10  
  5. 4 4 4  
  6. 5 4 1  
  7. 6 6 1  
  8. 7 6 3  
  9. 8 6 9  
  10. 9 8 3  
  11. 10 8 10  
  12. 11 9 2  
  13. 12 9 11  
  14. 13 10 9  
  15. 14 11 12  
总共15个点。

坐标点类Point.java:

[java] view plaincopyprint?
  1. package DataMining_CABDDCC;  
  2.   
  3.   
  4.   
  5. /** 
  6.  * 坐标点类 
  7.  * @author lyq 
  8.  * 
  9.  */  
  10. public class Point implements Comparable<Point>{  
  11.     //坐标点id号,id号唯一  
  12.     int id;  
  13.     //坐标横坐标  
  14.     Integer x;  
  15.     //坐标纵坐标  
  16.     Integer y;  
  17.     //坐标点是否已经被访问(处理)过,在生成连通子图的时候用到  
  18.     boolean isVisited;  
  19.       
  20.     public Point(String id, String x, String y){  
  21.         this.id = Integer.parseInt(id);  
  22.         this.x = Integer.parseInt(x);  
  23.         this.y = Integer.parseInt(y);  
  24.     }  
  25.       
  26.     /** 
  27.      * 计算当前点与制定点之间的欧式距离 
  28.      *  
  29.      * @param p 
  30.      *            待计算聚类的p点 
  31.      * @return 
  32.      */  
  33.     public double ouDistance(Point p) {  
  34.         double distance = 0;  
  35.   
  36.         distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)  
  37.                 * (this.y - p.y);  
  38.         distance = Math.sqrt(distance);  
  39.   
  40.         return distance;  
  41.     }  
  42.       
  43.     /** 
  44.      * 判断2个坐标点是否为用个坐标点 
  45.      *  
  46.      * @param p 
  47.      *            待比较坐标点 
  48.      * @return 
  49.      */  
  50.     public boolean isTheSame(Point p) {  
  51.         boolean isSamed = false;  
  52.   
  53.         if (this.x == p.x && this.y == p.y) {  
  54.             isSamed = true;  
  55.         }  
  56.   
  57.         return isSamed;  
  58.     }  
  59.   
  60.     @Override  
  61.     public int compareTo(Point p) {  
  62.         if(this.x.compareTo(p.x) != 0){  
  63.             return this.x.compareTo(p.x);  
  64.         }else{  
  65.             //如果在x坐标相等的情况下比较y坐标  
  66.             return this.y.compareTo(p.y);  
  67.         }  
  68.     }  
  69. }  
连通图类Graph.java:

[java] view plaincopyprint?
  1. package DataMining_CABDDCC;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Collections;  
  5.   
  6. /** 
  7.  * 连通图类 
  8.  *  
  9.  * @author lyq 
  10.  *  
  11.  */  
  12. public class Graph {  
  13.     // 坐标点之间的连接属性,括号内为坐标id号  
  14.     int[][] edges;  
  15.     // 连通图内的坐标点数  
  16.     ArrayList<Point> points;  
  17.     // 此图下分割后的聚类子图  
  18.     ArrayList<ArrayList<Point>> clusters;  
  19.   
  20.     public Graph(int[][] edges) {  
  21.         this.edges = edges;  
  22.         this.points = getPointByEdges(edges);  
  23.     }  
  24.   
  25.     public Graph(int[][] edges, ArrayList<Point> points) {  
  26.         this.edges = edges;  
  27.         this.points = points;  
  28.     }  
  29.   
  30.     public int[][] getEdges() {  
  31.         return edges;  
  32.     }  
  33.   
  34.     public void setEdges(int[][] edges) {  
  35.         this.edges = edges;  
  36.     }  
  37.   
  38.     public ArrayList<Point> getPoints() {  
  39.         return points;  
  40.     }  
  41.   
  42.     public void setPoints(ArrayList<Point> points) {  
  43.         this.points = points;  
  44.     }  
  45.   
  46.     /** 
  47.      * 根据距离阈值做连通图的划分,构成连通图集 
  48.      *  
  49.      * @param length 
  50.      *            距离阈值 
  51.      * @return 
  52.      */  
  53.     public ArrayList<Graph> splitGraphByLength(int length) {  
  54.         int[][] edges;  
  55.         Graph tempGraph;  
  56.         ArrayList<Graph> graphs = new ArrayList<>();  
  57.   
  58.         for (Point p : points) {  
  59.             if (!p.isVisited) {  
  60.                 // 括号中的下标为id号  
  61.                 edges = new int[points.size()][points.size()];  
  62.                 dfsExpand(p, length, edges);  
  63.   
  64.                 tempGraph = new Graph(edges);  
  65.                 graphs.add(tempGraph);  
  66.             } else {  
  67.                 continue;  
  68.             }  
  69.         }  
  70.   
  71.         return graphs;  
  72.     }  
  73.   
  74.     /** 
  75.      * 深度优先方式扩展连通图 
  76.      *  
  77.      * @param points 
  78.      *            需要继续深搜的坐标点 
  79.      * @param length 
  80.      *            距离阈值 
  81.      * @param edges 
  82.      *            边数组 
  83.      */  
  84.     private void dfsExpand(Point point, int length, int edges[][]) {  
  85.         int id1 = 0;  
  86.         int id2 = 0;  
  87.         double distance = 0;  
  88.         ArrayList<Point> tempPoints;  
  89.   
  90.         // 如果处理过了,则跳过  
  91.         if (point.isVisited) {  
  92.             return;  
  93.         }  
  94.   
  95.         id1 = point.id;  
  96.         point.isVisited = true;  
  97.         tempPoints = new ArrayList<>();  
  98.         for (Point p2 : points) {  
  99.             id2 = p2.id;  
  100.   
  101.             if (id1 == id2) {  
  102.                 continue;  
  103.             } else {  
  104.                 distance = point.ouDistance(p2);  
  105.                 if (distance <= length) {  
  106.                     edges[id1][id2] = 1;  
  107.                     edges[id2][id1] = 1;  
  108.   
  109.                     tempPoints.add(p2);  
  110.                 }  
  111.             }  
  112.         }  
  113.   
  114.         // 继续递归  
  115.         for (Point p : tempPoints) {  
  116.             dfsExpand(p, length, edges);  
  117.         }  
  118.     }  
  119.   
  120.     /** 
  121.      * 判断连通图是否还需要再被划分 
  122.      *  
  123.      * @param pointList1 
  124.      *            坐标点集合1 
  125.      * @param pointList2 
  126.      *            坐标点集合2 
  127.      * @return 
  128.      */  
  129.     private boolean needDivided(ArrayList<Point> pointList1,  
  130.             ArrayList<Point> pointList2) {  
  131.         boolean needDivided = false;  
  132.         // 承受系数t=轻的集合的坐标点数/2部分连接的边数  
  133.         double t = 0;  
  134.         // 分裂阈值,即平均每边所要承受的重量  
  135.         double landa = 0;  
  136.         int pointNum1 = pointList1.size();  
  137.         int pointNum2 = pointList2.size();  
  138.         // 总边数  
  139.         int totalEdgeNum = 0;  
  140.         // 连接2部分的边数量  
  141.         int connectedEdgeNum = 0;  
  142.         ArrayList<Point> totalPoints = new ArrayList<>();  
  143.   
  144.         totalPoints.addAll(pointList1);  
  145.         totalPoints.addAll(pointList2);  
  146.         int id1 = 0;  
  147.         int id2 = 0;  
  148.         for (Point p1 : totalPoints) {  
  149.             id1 = p1.id;  
  150.             for (Point p2 : totalPoints) {  
  151.                 id2 = p2.id;  
  152.   
  153.                 if (edges[id1][id2] == 1 && id1 < id2) {  
  154.                     if ((pointList1.contains(p1) && pointList2.contains(p2))  
  155.                             || (pointList1.contains(p2) && pointList2  
  156.                                     .contains(p1))) {  
  157.                         connectedEdgeNum++;  
  158.                     }  
  159.                     totalEdgeNum++;  
  160.                 }  
  161.             }  
  162.         }  
  163.   
  164.         if (pointNum1 < pointNum2) {  
  165.             // 承受系数t=轻的集合的坐标点数/连接2部分的边数  
  166.             t = 1.0 * pointNum1 / connectedEdgeNum;  
  167.         } else {  
  168.             t = 1.0 * pointNum2 / connectedEdgeNum;  
  169.         }  
  170.   
  171.         // 计算分裂阈值,括号内为总边数/总点数,就是平均每边所承受的点数量  
  172.         landa = 0.5 * Math.exp((1.0 * totalEdgeNum / (pointNum1 + pointNum2)));  
  173.   
  174.         // 如果承受系数不小于分裂阈值,则代表需要分裂  
  175.         if (t >= landa) {  
  176.             needDivided = true;  
  177.         }  
  178.   
  179.         return needDivided;  
  180.     }  
  181.   
  182.     /** 
  183.      * 递归的划分连通图 
  184.      *  
  185.      * @param pointList 
  186.      *            待划分的连通图的所有坐标点 
  187.      */  
  188.     public void divideGraph(ArrayList<Point> pointList) {  
  189.         // 判断此坐标点集合是否能够被分割  
  190.         boolean canDivide = false;  
  191.         ArrayList<ArrayList<Point>> pointGroup;  
  192.         ArrayList<Point> pointList1 = new ArrayList<>();  
  193.         ArrayList<Point> pointList2 = new ArrayList<>();  
  194.   
  195.         for (int m = 2; m <= pointList.size() / 2; m++) {  
  196.             // 进行坐标点的分割  
  197.             pointGroup = removePoint(pointList, m);  
  198.             pointList1 = pointGroup.get(0);  
  199.             pointList2 = pointGroup.get(1);  
  200.   
  201.             // 判断是否满足分裂条件  
  202.             if (needDivided(pointList1, pointList2)) {  
  203.                 canDivide = true;  
  204.                 divideGraph(pointList1);  
  205.                 divideGraph(pointList2);  
  206.             }  
  207.         }  
  208.   
  209.         // 如果所有的分割组合都无法分割,则说明此已经是一个聚类  
  210.         if (!canDivide) {  
  211.             clusters.add(pointList);  
  212.         }  
  213.     }  
  214.   
  215.     /** 
  216.      * 获取分裂得到的聚类结果 
  217.      *  
  218.      * @return 
  219.      */  
  220.     public ArrayList<ArrayList<Point>> getClusterByDivding() {  
  221.         clusters = new ArrayList<>();  
  222.           
  223.         divideGraph(points);  
  224.   
  225.         return clusters;  
  226.     }  
  227.   
  228.     /** 
  229.      * 将当前坐标点集合移除removeNum个点,构成2个子坐标点集合 
  230.      *  
  231.      * @param pointList 
  232.      *            原集合点 
  233.      * @param removeNum 
  234.      *            移除的数量 
  235.      */  
  236.     private ArrayList<ArrayList<Point>> removePoint(ArrayList<Point> pointList,  
  237.             int removeNum) {  
  238.         //浅拷贝一份原坐标点数据  
  239.         ArrayList<Point> copyPointList = (ArrayList<Point>) pointList.clone();  
  240.         ArrayList<ArrayList<Point>> pointGroup = new ArrayList<>();  
  241.         ArrayList<Point> pointList2 = new ArrayList<>();  
  242.         // 进行按照坐标轴大小排序  
  243.         Collections.sort(copyPointList);  
  244.   
  245.         for (int i = 0; i < removeNum; i++) {  
  246.             pointList2.add(copyPointList.get(i));  
  247.         }  
  248.         copyPointList.removeAll(pointList2);  
  249.   
  250.         pointGroup.add(copyPointList);  
  251.         pointGroup.add(pointList2);  
  252.   
  253.         return pointGroup;  
  254.     }  
  255.   
  256.     /** 
  257.      * 根据边的情况获取其中的点 
  258.      *  
  259.      * @param edges 
  260.      *            当前的已知的边的情况 
  261.      * @return 
  262.      */  
  263.     private ArrayList<Point> getPointByEdges(int[][] edges) {  
  264.         Point p1;  
  265.         Point p2;  
  266.         ArrayList<Point> pointList = new ArrayList<>();  
  267.   
  268.         for (int i = 0; i < edges.length; i++) {  
  269.             for (int j = 0; j < edges[0].length; j++) {  
  270.                 if (edges[i][j] == 1) {  
  271.                     p1 = CABDDCCTool.totalPoints.get(i);  
  272.                     p2 = CABDDCCTool.totalPoints.get(j);  
  273.   
  274.                     if (!pointList.contains(p1)) {  
  275.                         pointList.add(p1);  
  276.                     }  
  277.   
  278.                     if (!pointList.contains(p2)) {  
  279.                         pointList.add(p2);  
  280.                     }  
  281.                 }  
  282.             }  
  283.         }  
  284.   
  285.         return pointList;  
  286.     }  
  287. }  
算法工具类:

[java] view plaincopyprint?
  1. package DataMining_CABDDCC;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.text.MessageFormat;  
  8. import java.util.ArrayList;  
  9.   
  10. /** 
  11.  * 基于连通图的分裂聚类算法 
  12.  *  
  13.  * @author lyq 
  14.  *  
  15.  */  
  16. public class CABDDCCTool {  
  17.     // 测试数据点数据  
  18.     private String filePath;  
  19.     // 连通图距离阈值l  
  20.     private int length;  
  21.     // 原始坐标点  
  22.     public static ArrayList<Point> totalPoints;  
  23.     // 聚类结果坐标点集合  
  24.     private ArrayList<ArrayList<Point>> resultClusters;  
  25.     // 连通图  
  26.     private Graph graph;  
  27.   
  28.     public CABDDCCTool(String filePath, int length) {  
  29.         this.filePath = filePath;  
  30.         this.length = length;  
  31.   
  32.         readDataFile();  
  33.     }  
  34.   
  35.     /** 
  36.      * 从文件中读取数据 
  37.      */  
  38.     public void readDataFile() {  
  39.         File file = new File(filePath);  
  40.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  41.   
  42.         try {  
  43.             BufferedReader in = new BufferedReader(new FileReader(file));  
  44.             String str;  
  45.             String[] tempArray;  
  46.             while ((str = in.readLine()) != null) {  
  47.                 tempArray = str.split(" ");  
  48.                 dataArray.add(tempArray);  
  49.             }  
  50.             in.close();  
  51.         } catch (IOException e) {  
  52.             e.getStackTrace();  
  53.         }  
  54.   
  55.         Point p;  
  56.         totalPoints = new ArrayList<>();  
  57.         for (String[] array : dataArray) {  
  58.             p = new Point(array[0], array[1], array[2]);  
  59.             totalPoints.add(p);  
  60.         }  
  61.   
  62.         // 用边和点构造图  
  63.         graph = new Graph(null, totalPoints);  
  64.     }  
  65.   
  66.     /** 
  67.      * 分裂连通图得到聚类 
  68.      */  
  69.     public void splitCluster() {  
  70.         // 获取形成连通子图  
  71.         ArrayList<Graph> subGraphs;  
  72.         ArrayList<ArrayList<Point>> pointList;  
  73.         resultClusters = new ArrayList<>();  
  74.   
  75.         subGraphs = graph.splitGraphByLength(length);  
  76.   
  77.         for (Graph g : subGraphs) {  
  78.             // 获取每个连通子图分裂后的聚类结果  
  79.             pointList = g.getClusterByDivding();  
  80.             resultClusters.addAll(pointList);  
  81.         }  
  82.           
  83.         printResultCluster();  
  84.     }  
  85.   
  86.     /** 
  87.      * 输出结果聚簇 
  88.      */  
  89.     private void printResultCluster() {  
  90.         int i = 1;  
  91.         for (ArrayList<Point> cluster : resultClusters) {  
  92.             System.out.print("聚簇" + i + ":");  
  93.             for (Point p : cluster){  
  94.                 System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y));  
  95.             }  
  96.             System.out.println();  
  97.             i++;  
  98.         }  
  99.           
  100.     }  
  101.   
  102. }  
算法调用类Client.java:

[java] view plaincopyprint?
  1. package DataMining_CABDDCC;  
  2.   
  3. /** 
  4.  * 基于连通图的分裂聚类算法 
  5.  * @author lyq 
  6.  * 
  7.  */  
  8. public class Client {  
  9.     public static void main(String[] agrs){  
  10.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\graphData.txt";  
  11.         //连通距离阈值  
  12.         int length = 3;  
  13.           
  14.         CABDDCCTool tool = new CABDDCCTool(filePath, length);  
  15.         tool.splitCluster();  
  16.     }  
  17. }  

算法的输出:

[java] view plaincopyprint?
  1. 聚簇1:(69) (810) (911) (109) (1112)   
  2. 聚簇2:(112) (39) (312) (410)   
  3. 聚簇3:(44) (41) (63) (61) (83) (92)   

图形化的展示结果如下,一张是连通图的有效边(就是e[i][j]=1)的情况,后张图是分裂的聚类结果:



图片有点大,就没有处理了,大家将就着看吧.....

算法的遗漏点和优点

其实这个算法我在实现的时候,其实少考虑了很多东西,首先一个是构造连通图的时候,可以从示例的图线中看出,最后的图应该是一个闭环图,而我通过类似于DBSCAN算法会导致最边界的点会暴露在外面,形成不了闭环,与题目所要求的会有点不符。还有1点是划分部分坐标点的时候,我默认是从左往右,从下往上的优先级的顺序进行划分,但是我觉得更加合理的方式应该是怎样的。还有1个算法的缺点是总是在不停的比较中,时间开销比较大。算法非常的新颖,用了图的思想去做聚类的实现,而且用了类似于扁担挑重物的原理运用到数据挖掘中,不愧是一篇好论文。像我目前就只能是站在巨人的肩膀上,做点小东西罢了....

0 0