基于Java实现K-means算法

来源:互联网 发布:淘宝同步电讯 编辑:程序博客网 时间:2024/05/21 22:23

1. K-means算法介绍

先放着,日后补上!

2. java实现代码

代码说明:
1. 测试数据的输入维度必须都相等才能运行!
    [1,2,3,3]和[2,3,4,4] 这样维度相同,都是4维,可以运行。
    [1,2,3,3]和[2,3,4] 这样维度不相同,不能运行。
2. 输入数据的点,尽量不要重复!
    输入的数据点[1,2,3]和[1,2,3] 这样的数据点是重复的。可以运行,但是性能会下降。

2.1 分装Point类

public class Point {    private float[] localArray;    private int id;    private int clusterId;  // 标识属于哪个类中心。    private float dist;     // 标识和所属类中心的距离。    public Point(int id, float[] localArray) {        this.id = id;        this.localArray = localArray;    }    public Point(float[] localArray) {        this.id = -1; //表示不属于任意一个类        this.localArray = localArray;    }    public float[] getlocalArray() {        return localArray;    }    public int getId() {        return id;    }    public void setClusterId(int clusterId) {        this.clusterId = clusterId;    }    public int getClusterid() {        return clusterId;    }    public float getDist() {        return dist;    }    public void setDist(float dist) {        this.dist = dist;    }    @Override    public String toString() {        String result = "Point_id=" + id + "  [";        for (int i = 0; i < localArray.length; i++) {            result += localArray[i] + " ";        }        return result.trim()+"] clusterId: "+clusterId+" dist: "+dist;    }    @Override    public boolean equals(Object obj) {        if (obj == null || getClass() != obj.getClass())            return false;        Point point = (Point) obj;        if (point.localArray.length != localArray.length)            return false;        for (int i = 0; i < localArray.length; i++) {            if (Float.compare(point.localArray[i], localArray[i]) != 0) {                return false;            }        }        return true;    }    @Override    public int hashCode() {        float x = localArray[0];        float y = localArray[localArray.length - 1];        long temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;        int result = (int) (temp ^ (temp >>> 32));        temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;        result = 31 * result + (int) (temp ^ (temp >>> 32));        return result;    }}

2.2 封装Cluster

import java.util.ArrayList;import java.util.List;public class Cluster {    private int id;// 标识    private Point center;// 中心    private List<Point> members = new ArrayList<Point>();// 成员    public Cluster(int id, Point center) {        this.id = id;        this.center = center;    }    public Cluster(int id, Point center, List<Point> members) {        this.id = id;        this.center = center;        this.members = members;    }    public void addPoint(Point newPoint) {        if (!members.contains(newPoint)){            members.add(newPoint);        }else{            System.out.println("样本数据点 {"+newPoint.toString()+"} 已经存在!");        }    }    public int getId() {        return id;    }    public Point getCenter() {        return center;    }    public void setCenter(Point center) {        this.center = center;    }    public List<Point> getMembers() {        return members;    }    @Override    public String toString() {        String toString = "Cluster \n" + "Cluster_id=" + this.id + ", center:{" + this.center.toString()+"}";        for (Point point : members) {            toString+="\n"+point.toString();        }        return toString+"\n";    }}

2.3 求两个点的欧式距离

public class DistanceCompute {    /**     * 求欧式距离     */    public double getEuclideanDis(Point p1, Point p2) {        double count_dis = 0;        float[] p1_local_array = p1.getlocalArray();        float[] p2_local_array = p2.getlocalArray();        if (p1_local_array.length != p2_local_array.length) {            throw new IllegalArgumentException("length of array must be equal!");        }        for (int i = 0; i < p1_local_array.length; i++) {            count_dis += Math.pow(p1_local_array[i] - p2_local_array[i], 2);        }        return Math.sqrt(count_dis);    }}

2.4 核心运行类

import java.util.ArrayList;import java.util.HashSet;import java.util.List;import java.util.Random;import java.util.Set;public class KMeansRun {      private int kNum;                             //簇的个数    private int iterNum = 10;                     //迭代次数    private int iterMaxTimes = 100000;            //单次迭代最大运行次数    private int iterRunTimes = 0;                 //单次迭代实际运行次数    private float disDiff = (float) 0.01;         //单次迭代终止条件,两次运行中类中心的距离差    private List<float[]> original_data =null;    //用于存放,原始数据集      private static List<Point> pointList = null;  //用于存放,原始数据集所构建的点集    private DistanceCompute disC = new DistanceCompute();    private int len = 0;                          //用于记录每个数据点的维度    public KMeansRun(int k, List<float[]> original_data) {        this.kNum = k;        this.original_data = original_data;        this.len = original_data.get(0).length;         //检查规范        check();        //初始化点集。        init();    }    /**     * 检查规范     */    private void check() {        if (kNum == 0){            throw new IllegalArgumentException("k must be the number > 0");          }        if (original_data == null){            throw new IllegalArgumentException("program can't get real data");        }    }     /**      * 初始化数据集,把数组转化为Point类型。     */    private void init() {        pointList = new ArrayList<Point>();        for (int i = 0, j = original_data.size(); i < j; i++){            pointList.add(new Point(i, original_data.get(i)));        }    }    /**      * 随机选取中心点,构建成中心类。     */      private Set<Cluster> chooseCenterCluster() {        Set<Cluster> clusterSet = new HashSet<Cluster>();        Random random = new Random();        for (int id = 0; id < kNum; ) {            Point point = pointList.get(random.nextInt(pointList.size()));            // 用于标记是否已经选择过该数据。            boolean flag =true;            for (Cluster cluster : clusterSet) {                if (cluster.getCenter().equals(point)) {                    flag = false;                }            }            // 如果随机选取的点没有被选中过,则生成一个cluster            if (flag) {                Cluster cluster =new Cluster(id, point);                clusterSet.add(cluster);                id++;            }        }        return clusterSet;      }    /**     * 为每个点分配一个类!     */    public void cluster(Set<Cluster> clusterSet){        // 计算每个点到K个中心的距离,并且为每个点标记类别号        for (Point point : pointList) {            float min_dis = Integer.MAX_VALUE;            for (Cluster cluster : clusterSet) {                float tmp_dis = (float) Math.min(disC.getEuclideanDis(point, cluster.getCenter()), min_dis);                if (tmp_dis != min_dis) {                    min_dis = tmp_dis;                    point.setClusterId(cluster.getId());                    point.setDist(min_dis);                }            }        }        // 新清除原来所有的类中成员。把所有的点,分别加入每个类别        for (Cluster cluster : clusterSet) {            cluster.getMembers().clear();            for (Point point : pointList) {                if (point.getClusterid()==cluster.getId()) {                    cluster.addPoint(point);                }            }        }    }    /**     * 计算每个类的中心位置!     */    public boolean calculateCenter(Set<Cluster> clusterSet) {        boolean ifNeedIter = false;         for (Cluster cluster : clusterSet) {            List<Point> point_list = cluster.getMembers();            float[] sumAll =new float[len];            // 所有点,对应各个维度进行求和            for (int i = 0; i < len; i++) {                for (int j = 0; j < point_list.size(); j++) {                    sumAll[i] += point_list.get(j).getlocalArray()[i];                }            }            // 计算平均值            for (int i = 0; i < sumAll.length; i++) {                sumAll[i] = (float) sumAll[i]/point_list.size();            }            // 计算两个新、旧中心的距离,如果任意一个类中心移动的距离大于dis_diff则继续迭代。            if(disC.getEuclideanDis(cluster.getCenter(), new Point(sumAll)) > disDiff){                ifNeedIter = true;            }            // 设置新的类中心位置            cluster.setCenter(new Point(sumAll));        }        return ifNeedIter;    }    /**     * 运行 k-means     */    public Set<Cluster> run() {        Set<Cluster> clusterSet= chooseCenterCluster();        boolean ifNeedIter = true;         while (ifNeedIter) {            cluster(clusterSet);            ifNeedIter = calculateCenter(clusterSet);            iterRunTimes ++ ;        }        return clusterSet;    }    /**     * 返回实际运行次数     */    public int getIterTimes() {        return iterRunTimes;    }}

2.5 测试类

import java.util.ArrayList;import java.util.Set;public class Main {    public static void main(String[] args) {        ArrayList<float[]> dataSet = new ArrayList<float[]>();        dataSet.add(new float[] { 1, 2, 3 });        dataSet.add(new float[] { 3, 3, 3 });        dataSet.add(new float[] { 3, 4, 4});        dataSet.add(new float[] { 5, 6, 5});        dataSet.add(new float[] { 8, 9, 6});        dataSet.add(new float[] { 4, 5, 4});        dataSet.add(new float[] { 6, 4, 2});        dataSet.add(new float[] { 3, 9, 7});        dataSet.add(new float[] { 5, 9, 8});        dataSet.add(new float[] { 4, 2, 10});        dataSet.add(new float[] { 1, 9, 12});        dataSet.add(new float[] { 7, 8, 112});        dataSet.add(new float[] { 7, 8, 4});        KMeansRun kRun =new KMeansRun(3, dataSet);        Set<Cluster> clusterSet = kRun.run();        System.out.println("单次迭代运行次数:"+kRun.getIterTimes());        for (Cluster cluster : clusterSet) {            System.out.println(cluster);        }    }}

3. 关于输入数据和运行结果

这里故意加入了一个异常点[7 , 8 ,112]进行测试,果然独自分为一类:
这里写图片描述

0 0
原创粉丝点击