Kmeans聚类算法-二维度数组(Java实现)

来源:互联网 发布:迪杰斯特拉算法 编辑:程序博客网 时间:2024/05/01 18:46

本文章转载至:http://blog.csdn.net/cyxlzzs/article/details/7416491

源码

Kmeans.java文件源码如下:

package com.bigdata.ml.cluster;import java.util.ArrayList;import java.util.Random;/** * 聚类算法通常用于数据挖掘,将相似的数组进行聚簇 *  * @author zouzhongfan * */public class Kmeans {private int k;// 分成多少簇private int m;// 迭代次数private int dataSetLength;// 数据集元素个数,即数据集的长度private ArrayList<float[]> dataSet;// 数据集链表private ArrayList<float[]> center;// 中心链表private ArrayList<ArrayList<float[]>> cluster; // 簇private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小private Random random;/** * 设置需分组的原始数据集 *  * @param dataSet */public void setDataSet(ArrayList<float[]> dataSet) {this.dataSet = dataSet;}/** * 获取结果分组 *  * @return 结果集 */public ArrayList<ArrayList<float[]>> getCluster() {return cluster;}/** * 构造函数,传入需要分成的簇数量 *  * @param k *            ,簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度 */public Kmeans(int k) {if (k <= 0) {k = 1;}this.k = k;}/** * 初始化 */private void init() {m = 0;random = new Random();if (dataSet == null || dataSet.size() == 0) {initDataSet();}dataSetLength = dataSet.size();// 若k大于数据源的长度时,置为数据源的长度if (k > dataSetLength) {k = dataSetLength;}center = initCenters();// 初始化中心cluster = initCluster();// 初始化簇集,分配内存,但元素为空jc = new ArrayList<Float>();// 初始化误差平方和}/** * 如果调用者未初始化数据集,则采用内部测试数据集 */private void initDataSet() {dataSet = new ArrayList<float[]>();// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };for (int i = 0; i < dataSetArray.length; i++) {dataSet.add(dataSetArray[i]);}}/** * 初始化中心数据链表,分成多少簇就有多少个中心点 *  * @return 中心点集 */private ArrayList<float[]> initCenters() {ArrayList<float[]> center = new ArrayList<float[]>();int[] randoms = new int[k];boolean flag;// 生成k个互补相同的随机数int temp = random.nextInt(dataSetLength);randoms[0] = temp;for (int i = 1; i < k; i++) {flag = true;while (flag) {temp = random.nextInt(dataSetLength);int j = 0;while (j < i) {if (temp == randoms[j]) {break;}j++;}if (j == i) {flag = false;}}randoms[i] = temp;}// 生成初始化中心链表for (int i = 0; i < k; i++) {center.add(dataSet.get(randoms[i]));}return center;}/** * 初始化簇集合 *  * @return 一个分为k簇的空数据的簇集合 */private ArrayList<ArrayList<float[]>> initCluster() {ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();for (int i = 0; i < k; i++) {cluster.add(new ArrayList<float[]>());}return cluster;}/** * 计算两个点之间的距离(欧几里得距离) *  * @param element *            点1 * @param center *            点2 * @return 距离 */private float distance(float[] element, float[] center) {float distance = 0.0f;float x = element[0] - center[0];float y = element[1] - center[1];float z = x * x + y * y;distance = (float) Math.sqrt(z);return distance;}/** * 获取距离集合中最小距离的位置 *  * @param distance *            距离数组 * @return 最小距离在距离数组中的位置 */private int minDistance(float[] distance) {float minDistance = distance[0];int minLocation = 0;for (int i = 1; i < distance.length; i++) {if (distance[i] < minDistance) {minDistance = distance[i];minLocation = i;} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置{if (random.nextInt(10) < 5) {minLocation = i;}}}return minLocation;}/** * 核心 计算两点之间的距离,并将当前元素放到最小距离中心的簇中 */private void clusterSet() {float[] distance = new float[k];for (int i = 0; i < dataSetLength; i++) {for (int j = 0; j < k; j++) {distance[j] = distance(dataSet.get(i), center.get(j));// 计算两个点之间的距离}int minLocation = minDistance(distance);cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心的簇中}}/** * 求两点误差平方的方法 *  * @param element *            点1 * @param center *            点2 * @return 误差平方 */private float errorSquare(float[] element, float[] center) {float x = element[0] - center[0];float y = element[1] - center[1];float errSquare = x * x + y * y;return errSquare;}/** * 计算误差平方和准则函数方法 */private void countRule() {float jcF = 0;for (int i = 0; i < cluster.size(); i++) {for (int j = 0; j < cluster.get(i).size(); j++) {jcF += errorSquare(cluster.get(i).get(j), center.get(i));}}jc.add(jcF);}/** * 设置新的簇中心方法 */private void setNewCenter() {for (int i = 0; i < k; i++) {int n = cluster.get(i).size();if (n != 0) {float[] newCenter = { 0, 0 };for (int j = 0; j < n; j++) {newCenter[0] += cluster.get(i).get(j)[0];newCenter[1] += cluster.get(i).get(j)[1];}// 设置一个平均值newCenter[0] = newCenter[0] / n;newCenter[1] = newCenter[1] / n;center.set(i, newCenter);}}}/** * 打印数据,测试用 *  * @param dataArray *            数据集 * @param dataArrayName *            数据集名称 */public void printDataArray(ArrayList<float[]> dataArray,String dataArrayName) {for (int i = 0; i < dataArray.size(); i++) {System.out.println("print:" + dataArrayName + "[" + i + "]={"+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");}System.out.println("===================================");}/** * Kmeans算法核心过程方法 */private void kmeans() {init();// 初始化printDataArray(dataSet, "initDataSet"); // 输出初始化数据集printDataArray(center, "initCenter"); // 输出初始化中心// 循环分组,直到误差不变为止while (true) {clusterSet(); // 生成簇集元素// 输出簇集生成结果for (int i = 0; i < cluster.size(); i++) {printDataArray(cluster.get(i), "cluster[" + i + "]");}countRule();// 计算误差平方和System.out.println("count:" + "jc[" + m + "]=" + jc.get(m));System.out.println();// 判断退出迭代条件,当最近两次的误差平方和相等,则退出迭代。if (m != 0) {if (jc.get(m) - jc.get(m - 1) == 0) {break;}}setNewCenter();// 计算新的中心printDataArray(center, "newCenter");// 输出新的中心m++;cluster.clear(); // 簇集清空cluster = initCluster(); // 簇集初始化}System.out.println("note:the times of repeat:m=" + m);// 输出迭代次数}/** * 执行算法 */public void execute() {long startTime = System.currentTimeMillis();System.out.println("kmeans begins");kmeans();long endTime = System.currentTimeMillis();System.out.println("kmeans running time=" + (endTime - startTime)+ "ms");System.out.println("kmeans ends");System.out.println();}public static void main(String[] args) {// 初始化一个Kmean对象,将k置为3Kmeans k = new Kmeans(3);ArrayList<float[]> dataSet = new ArrayList<float[]>();dataSet.add(new float[] { 1, 2 });dataSet.add(new float[] { 3, 3 });dataSet.add(new float[] { 3, 4 });dataSet.add(new float[] { 5, 6 });dataSet.add(new float[] { 8, 9 });dataSet.add(new float[] { 4, 5 });dataSet.add(new float[] { 6, 4 });dataSet.add(new float[] { 3, 9 });dataSet.add(new float[] { 5, 9 });dataSet.add(new float[] { 4, 2 });dataSet.add(new float[] { 1, 9 });dataSet.add(new float[] { 7, 8 });// 设置原始数据集k.setDataSet(dataSet);// 执行算法k.execute();// 得到聚类结果ArrayList<ArrayList<float[]>> cluster = k.getCluster();// 查看结果for (int i = 0; i < cluster.size(); i++) {k.printDataArray(cluster.get(i), "cluster[" + i + "]");}}}


测试

测试结果如下:

kmeans begins
print:initDataSet[0]={1.0,2.0}
print:initDataSet[1]={3.0,3.0}
print:initDataSet[2]={3.0,4.0}
print:initDataSet[3]={5.0,6.0}
print:initDataSet[4]={8.0,9.0}
print:initDataSet[5]={4.0,5.0}
print:initDataSet[6]={6.0,4.0}
print:initDataSet[7]={3.0,9.0}
print:initDataSet[8]={5.0,9.0}
print:initDataSet[9]={4.0,2.0}
print:initDataSet[10]={1.0,9.0}
print:initDataSet[11]={7.0,8.0}
===================================
print:initCenter[0]={3.0,9.0}
print:initCenter[1]={4.0,5.0}
print:initCenter[2]={1.0,9.0}
===================================
print:cluster[0][0]={8.0,9.0}
print:cluster[0][1]={3.0,9.0}
print:cluster[0][2]={5.0,9.0}
print:cluster[0][3]={7.0,8.0}
===================================
print:cluster[1][0]={1.0,2.0}
print:cluster[1][1]={3.0,3.0}
print:cluster[1][2]={3.0,4.0}
print:cluster[1][3]={5.0,6.0}
print:cluster[1][4]={4.0,5.0}
print:cluster[1][5]={6.0,4.0}
print:cluster[1][6]={4.0,2.0}
===================================
print:cluster[2][0]={1.0,9.0}
===================================
count:jc[0]=87.0


print:newCenter[0]={5.75,8.75}
print:newCenter[1]={3.7142856,3.7142856}
print:newCenter[2]={1.0,9.0}
===================================
print:cluster[0][0]={8.0,9.0}
print:cluster[0][1]={5.0,9.0}
print:cluster[0][2]={7.0,8.0}
===================================
print:cluster[1][0]={1.0,2.0}
print:cluster[1][1]={3.0,3.0}
print:cluster[1][2]={3.0,4.0}
print:cluster[1][3]={5.0,6.0}
print:cluster[1][4]={4.0,5.0}
print:cluster[1][5]={6.0,4.0}
print:cluster[1][6]={4.0,2.0}
===================================
print:cluster[2][0]={3.0,9.0}
print:cluster[2][1]={1.0,9.0}
===================================
count:jc[1]=40.732143


print:newCenter[0]={6.6666665,8.666667}
print:newCenter[1]={3.7142856,3.7142856}
print:newCenter[2]={2.0,9.0}
===================================
print:cluster[0][0]={8.0,9.0}
print:cluster[0][1]={5.0,9.0}
print:cluster[0][2]={7.0,8.0}
===================================
print:cluster[1][0]={1.0,2.0}
print:cluster[1][1]={3.0,3.0}
print:cluster[1][2]={3.0,4.0}
print:cluster[1][3]={5.0,6.0}
print:cluster[1][4]={4.0,5.0}
print:cluster[1][5]={6.0,4.0}
print:cluster[1][6]={4.0,2.0}
===================================
print:cluster[2][0]={3.0,9.0}
print:cluster[2][1]={1.0,9.0}
===================================
count:jc[2]=36.190475


print:newCenter[0]={6.6666665,8.666667}
print:newCenter[1]={3.7142856,3.7142856}
print:newCenter[2]={2.0,9.0}
===================================
print:cluster[0][0]={8.0,9.0}
print:cluster[0][1]={5.0,9.0}
print:cluster[0][2]={7.0,8.0}
===================================
print:cluster[1][0]={1.0,2.0}
print:cluster[1][1]={3.0,3.0}
print:cluster[1][2]={3.0,4.0}
print:cluster[1][3]={5.0,6.0}
print:cluster[1][4]={4.0,5.0}
print:cluster[1][5]={6.0,4.0}
print:cluster[1][6]={4.0,2.0}
===================================
print:cluster[2][0]={3.0,9.0}
print:cluster[2][1]={1.0,9.0}
===================================
count:jc[3]=36.190475


note:the times of repeat:m=3
kmeans running time=5ms
kmeans ends


print:cluster[0][0]={8.0,9.0}
print:cluster[0][1]={5.0,9.0}
print:cluster[0][2]={7.0,8.0}
===================================
print:cluster[1][0]={1.0,2.0}
print:cluster[1][1]={3.0,3.0}
print:cluster[1][2]={3.0,4.0}
print:cluster[1][3]={5.0,6.0}
print:cluster[1][4]={4.0,5.0}
print:cluster[1][5]={6.0,4.0}
print:cluster[1][6]={4.0,2.0}
===================================
print:cluster[2][0]={3.0,9.0}
print:cluster[2][1]={1.0,9.0}
===================================



0 0
原创粉丝点击