KMeans on Spark

Approach:

1. Randomly generate the data points.

2. Randomly pick K initial cluster centers.

3. Assign each point to the cluster whose center is nearest.

4. Compute a new center for each cluster.

5. Check how far the centers moved: if the total movement exceeds the threshold, go back to step 3; otherwise stop (formalized below).
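
Steps 4 and 5 can be stated precisely. Writing \mu_i for the i-th center and C_i for the set of points currently assigned to it, one iteration computes (notation mine, restating the loop in the code below):

\mu_i' = \frac{1}{|C_i|} \sum_{p \in C_i} p, \qquad \Delta = \sum_{i=1}^{K} \lVert \mu_i' - \mu_i \rVert^2

and the iteration stops once \Delta < convergeDist (0.01 here).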

package myclass

import java.util.Random

import org.apache.spark.SparkContext
import SparkContext._
import org.apache.spark.util.Vector

/**
 * Created by jack on 2/26/14.
 */
object MyKMeans {

  val N = 1000              // number of points to generate
  val R = 1000              // random coordinate range: 0 to R
  val D = 10                // dimensionality of the point space
  val K = 10                // number of cluster centers
  val rand = new Random(42) // fixed random seed
  val convergeDist = 0.01   // convergence threshold for the iteration

  /**
   * Assign p to the cluster whose center is nearest (by squared Euclidean distance).
   */
  def closestPoint(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var closest = Double.PositiveInfinity
    for (i <- 0 until centers.length) {
      val tempDist = p.squaredDist(centers(i))
      if (tempDist < closest) {
        closest = tempDist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Generate N random D-dimensional points, each coordinate in [0, R).
   */
  def generateData = {
    def generatePoint(i: Int) = {
      Vector(D, _ => rand.nextDouble * R)
    }
    Array.tabulate(N)(generatePoint)
  }

  def main(args: Array[String]) {
    val sc = new SparkContext("local", "My KMeans", System.getenv("SPARK_HOME"),
      SparkContext.jarOfClass(this.getClass))
    val data = sc.parallelize(generateData).cache()

    // Randomly sample K points from the data as the initial cluster centers
    val kPoints = data.takeSample(false, K, 42).toArray

    var tempDist = 1.0
    while (tempDist > convergeDist) {
      // closest is (clusterId, (point, 1)); the 1 is used below to count points per cluster
      val closest = data.map(p => (closestPoint(p, kPoints), (p, 1)))

      // Per cluster: sum the point vectors and count the points -> (clusterId, (vectorSum, count))
      val pointStats = closest.reduceByKey {
        case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)
      }

      // New cluster centers as a Map(clusterId -> new center), i.e. vectorSum / count
      val newPoints = pointStats.map { pair =>
        (pair._1, pair._2._1 / pair._2._2)
      }.collectAsMap()

      // Total squared movement of the centers; this assumes every cluster kept at
      // least one point, otherwise newPoints(i) throws NoSuchElementException
      tempDist = 0.0
      for (i <- 0 until K) {
        tempDist += kPoints(i).squaredDist(newPoints(i))
      }

      // Update kPoints with the new centers
      for (newP <- newPoints) {
        kPoints(newP._1) = newP._2
      }
      println("Finished iteration (delta = " + tempDist + ")")
    }

    println("Final centers:")
    kPoints.foreach(println)
    System.exit(0)
  }
}
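
Note that org.apache.spark.util.Vector used above was deprecated in Spark 1.0 in favor of org.apache.spark.mllib.linalg.Vector. On a current Spark the same experiment is usually written against MLlib's built-in KMeans instead of a hand-rolled loop; a minimal sketch (object name, cluster count, and iteration cap are my choices, not part of the original post):

import java.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object MLlibKMeansSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "MLlib KMeans")
    val rand = new Random(42)

    // Same synthetic data as above: 1000 points, 10 dimensions, coordinates in [0, 1000)
    val data = sc.parallelize(
      Array.fill(1000)(Vectors.dense(Array.fill(10)(rand.nextDouble * 1000)))
    ).cache()

    // Train with K = 10 clusters and at most 20 iterations
    val model = KMeans.train(data, 10, 20)

    println("Final centers:")
    model.clusterCenters.foreach(println)
    println("Within-set sum of squared errors = " + model.computeCost(data))
    sc.stop()
  }
}

KMeans.train handles assignment, center averaging, and convergence internally, so the explicit map/reduceByKey loop above is mainly useful for seeing how those steps distribute across the cluster.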

