Java实现聚类分析Kmeans算法

来源:互联网 发布:成熟电子病历系统源码 编辑:程序博客网 时间:2024/05/17 22:40


package com.wgw.util;


/**
 * 作用:将一组无序数据自动归类
 *均值聚类算法
 * @author wanggw
 */
public class KMeans {


    static int k = 4;
    static double value = 0.01;
    public static double[] newk;


    public static void main(String[] args) {
        // TODO Auto-generated method stub
        double[] data = new double[]{-2, 1.3, 1.4, 8, 1.4, 1.5, -10, 45, 5, 6,600,48,30,300};
        //指定要计算的数据最终分为几组
        double[][] means = cluster(data, 4);
        print(means);
//        System.out.println(KMeans.findCenter(data));


    }


    /**
     * 聚类入口
     *
     * @param data
     * @param k
     * @return
     */
    public static double[][] cluster(double[] data, int k) {
        double[] initk = new double[k];
        double[][] result = null;
        //聚点初始化,默认以data集的前k个为初始聚点
        for (int i = 0; i < k; i++) {
            initk[i] = data[i];
        }
        //计算每个点到各聚点的距离,并归类到距离最近的点
        result = query(data, initk);
        return result;
    }


    private static void print(double[][] d) {
        for (int i = 0; i < d.length; i++) {
            for (int j = 0; j < d[i].length; j++) {
                System.out.print(d[i][j] + "\t");
            }
            System.out.println("");
        }
    }


    /**
     * 循环计算聚点直到不发生变化或者到达阀值
     *
     * @param data
     * @param k
     * @return
     */
    private static double[][] query(double[] data, double[] k) {


        double[][] group = null;
        boolean flag = true;
        double[] newk = null;
        while (flag) {
            group = new double[k.length][0];
            for (int i = 0; i < data.length; i++) {
                double[] results = new double[k.length];
                for (int j = 0; j < k.length; j++) {
                    results[j] = length(data[i], k[j]);
                }
                //归类到距离最近的聚点


                group = split(results, data[i], group);


            }
            //分类完毕,计算新的聚点
            newk = findk(group);
            if (!equal(k, newk)) {
                k = newk;
            } else {
                break;
            }
        }
        return group;
    }


    private static boolean equal(double[] oldk, double[] newk) {
        for (int i = 0; i < oldk.length; i++) {
            if (oldk[i] != newk[i] || Math.abs(oldk[i] - newk[i]) > value) {
                return false;
            }
        }
        KMeans.newk = newk;
        return true;
    }


    private static double[] findk(double[][] group) {
        double[] newk = new double[group.length];
        for (int i = 0; i < group.length; i++) {
            double[] data = group[i];
            double sum = 0;
            for (int j = 0; j < data.length; j++) {
                sum = sum + data[j];
            }
            newk[i] = sum / data.length;


        }
        return newk;


    }


    /**
     * 计算各点到其它点的距离之和,返回最小距离之和的点
     *
     * @param data
     * @return
     */
    public static double findCenter(double[] data) {
        double min = Double.MAX_VALUE;
        int index = 0;
        if (data != null && data.length > 1) {
            for (int i = 0; i < data.length; i++) {
                double nextMin = 0;
                for (int j = 0; j < data.length; j++) {
                    nextMin += Math.abs(data[i] - data[j]);
                }
                if (min > nextMin) {
                    index = i;
                }
            }
            return data[index];
        }else{
            return 0.0;
        }


    }


    private static double length(double a, double b) {
        //以点之间的距离计算,多维的数据可以扩展
        double n = Math.abs(a - b);
        return n;
    }


    /**
     *
     * @param d 距离数组
     * @param km 分类数组
     * @return
     */
    private static double[][] split(double[] d, double m, double[][] km) {
        //记录最小距离
        double min = d[0];
        //记录最小距离的下标
        int k = 0;
        for (int i = 1; i < d.length; i++) {
            if (min > d[i]) {
                min = d[i];
                k = i;
            }
        }
        km[k] = teamAdd(km[k], m);
        return km;


    }


    public static double[] teamAdd(double[] d, double m) {
        if (d.length == 0) {
            return new double[]{m};
        }
        double[] newd = new double[d.length + 1];
        for (int i = 0; i < d.length; i++) {
            newd[i] = d[i];
        }
        newd[d.length] = m;
        return newd;
    }
}
0 0
原创粉丝点击