Kmeans聚类算法 java精简版设计实现编程

来源:互联网 发布:淘宝 麻辣小黑粉 编辑:程序博客网 时间:2024/05/22 12:03

网上有许多Kmeans写的java算法,当然依据个人编码风格的不同,导致编写出来的代码,各有不同。所以在理解原理的基础上,最好就是按照自己设计思路将代码自己写出来。

度娘搜Kmeans的基本原理吧,直接上代码,代码中都有注释:

package net.codeal.suanfa.kmeans;import java.util.Set;/** *  * @ClassName: Distancable  * @Description: TODO(可计算两点之间距离的可中心化的父类)  * @author fuhuaguo * @date 2015年9月1日 上午11:41:23  * */public class Kmeansable<E> {/** * 获取两点之间的距离 * @param other * @return */public double getDistance(E other){return 0;}/** * 获取新的中心点 * @param eSet * @return */public E getNewCenter(Set<E> eSet){return null;}}

package net.codeal.suanfa.kmeans;import java.util.Set;/** *  * @ClassName: Point  * @Description: TODO(聚类的维度信息bean,可以分为K个维度,相似度计算是自身行为,放在bean内部才合适,取消注解使用)  * @author fuhuaguo * @email fhg@jusfoun.com * @date 2015年9月1日 上午10:43:25  * */public class Point extends Kmeansable<Point>{private String id;//维度1private double k1;//维度2private double k2;//维度3private double k3;public Point() {}public Point(String id,double k1,double k2,double k3) {this.id = id;this.k1 = k1;this.k2 = k2;this.k3 = k3;}/** * 计算和另一个点的距离,采用欧几里得算法 ,计算维度算数平方和的sqrt值,即:相异度 * @param other * @return */@Overridepublic double getDistance(Point other){return Math.sqrt((this.k1-other.getK1())*(this.k1-other.getK1())+ (this.k2-other.getK2())*(this.k2-other.getK2())+ (this.k3-other.getK3())*(this.k3-other.getK3()));}@Overridepublic Point getNewCenter(Set<Point> eSet) {if(eSet == null || eSet.size() == 0){return this;}Point temp = new Point();int count = 0;for (Point p : eSet) {temp.setK1(temp.getK1() + p.getK1());temp.setK2(temp.getK2() + p.getK2());temp.setK3(temp.getK3() + p.getK3());count++;}temp.setK1(temp.getK1()/count);temp.setK2(temp.getK2()/count);temp.setK3(temp.getK3()/count);return temp;}@Overridepublic boolean equals(Object obj) {if(obj == null || !(obj instanceof Point))return false;Point other = (Point) obj;return (this.k1 == other.getK1()) && (this.k2 == other.getK2()) && (this.k3 == other.getK3());}@Overridepublic int hashCode() {return new Double(k1+k2+k3).hashCode();}@Overridepublic String toString() {return "("+k1+","+k2+","+k3+")";} public String getId() {return id;}public void setId(String id) {this.id = id;}public double getK1() {return k1;}public void setK1(double k1) {this.k1 = k1;}public double getK2() {return k2;}public void setK2(double k2) {this.k2 = k2;}public double getK3() {return k3;}public void setK3(double k3) {this.k3 = k3;}}

package net.codeal.suanfa.kmeans;import java.util.HashMap;import java.util.HashSet;import java.util.Map;import java.util.Set;public class KmeansAlgorithm<E extends Kmeansable<E>> {/** * 对Set进行K个值聚类,计算深度最大为depth */public void kmeans(Set<E> dataSet, int k, int depth){//分类数设置不合适if(k <= 1 || dataSet.size() <= k){return;}Set<E> kSet = new HashSet<E>();int count = 0;//随机确定K个中心点for (E e : dataSet) {if(count++ >= k)break;kSet.add(e);}//计算每个值距离各个中心点的距离,分配到距离最小的那个中心上boolean flag = true;while(flag && depth > 0){Map<E, Set<E>> kMap = new HashMap<E, Set<E>>();for (E e : kSet) {kMap.put(e, new HashSet<E>());}//完成聚类for (E data : dataSet) {double d = Double.MAX_VALUE;E e = null;for (E center : kSet) {double d1 = data.getDistance(center);if (d > d1){e = center;d = d1;}}kMap.get(e).add(data);}//第一组计算完毕,同时获取新的中心点System.out.println("这是第"+depth+"次聚类");for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {System.out.println(m.getKey()+":"+m.getValue());}//获取新的聚类中心Set<E> oldSet = kSet;kSet = getNewCenters(kMap);flag = !isSameCenters(kSet,oldSet);depth--;}}/** * 获取新的中心点 列表 */public Set<E> getNewCenters(Map<E, Set<E>> kMap){Set<E> eSet = new HashSet<E>();for (Map.Entry<E, Set<E>> m : kMap.entrySet()) {eSet.add(m.getKey().getNewCenter(m.getValue()));}return eSet;}/** * 判断是否为同一个中心列表 */public boolean isSameCenters(Set<E> oldSet,Set<E> newSet){//两个集合只要交集为0就是相同的return oldSet.containsAll(newSet);}public static void main(String[] args) {Set<Point> dataSet = new HashSet<Point>();dataSet.add(new Point("1",1,1,1));dataSet.add(new Point("1",2,2,2));dataSet.add(new Point("1",5,6,1));dataSet.add(new Point("1",10,10,10));dataSet.add(new Point("1",11,11,11));new KmeansAlgorithm<Point>().kmeans(dataSet, 2,10);}}
结果:

这是第10次聚类
(1.0,1.0,1.0):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]
(10.0,10.0,10.0):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
这是第9次聚类
(10.5,10.5,10.5):[(10.0,10.0,10.0), (11.0,11.0,11.0)]
(2.6666666666666665,3.0,1.3333333333333333):[(1.0,1.0,1.0), (2.0,2.0,2.0), (5.0,6.0,1.0)]




1 0