Spark 中LocalKmeans算法详解

来源:互联网 发布:柠檬网络电视 柠檬tv 编辑:程序博客网 时间:2024/06/05 14:15


 

一、Kmeans算法思想

Kmeans算法的具体思想这里省略。

Kmeans算法实现步骤一般如下:

1、从D中随机取k个元素,作为k个簇的各自的中心,或者随机生成k个中心元素

2、分别计算剩下的元素到k个簇中心的相异度,将这些元素分别划归到相异度最低的簇。

3、根据聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所有元素各自维度的算术平均数。

4、将D中全部元素按照新的中心重新聚类。

5、重复第4步,直到聚类结果不再变化。

6、将结果输出。

 二、代码分析

关于generateData函数的解析:

这个函数的作用就是产生ND维的数据,程序中的变量已经定义好了

N = 1000,即1000个点

R = 1000,规约范围

D = 10,数据的维度,即10维的数据

 def generateData = {

    def generatePoint(i: Int) = {

      Vector(D, _ => rand.nextDouble * R)

    }

    Array.tabulate(N)(generatePoint)

  }

 

首先看几个函数:

rand函数:这个是随机产生Double类型的数据,rand.NextDouble产生的是0~1double类型的数据

Vector函数:此处的Vectororg.apache.spark.util.Vector类,代码中用到的是Vector的的apply方法,如下:

def apply(length: Int, initializer: Int => Double): Vector = {

    val elements: Array[Double] = Array.tabulate(length)(initializer)

    return new Vector(elements)

  }

这个apply方法,返回结果是一个Vector,Vector中的元素是一个Array,Array中有lengthdouble类型的元素,主要是依靠tabulate函数产生的。

tabulate函数:

 

14

def tabulate[T]( n: Int )(f: (Int)=> T): Array[T]
返回包含一个给定的函数的值超过从0开始的范围内的整数值的数组。

15

def tabulate[T]( n1: Int, n2: Int )( f: (Int, Int ) => T): Array[Array[T]]
返回一个包含给定函数的值超过整数值从0开始范围的二维数组。

 

根据定义,就是产生一个数组,数组中元素的个数为n,元素的范围为0到给定的数值范围,返回的类型为Array.

这样,Array.tabulate(length)(initializer)就是返回一个Array,元素个数为length,元素的范围为initializer。当我们调用时(Vector(D, _ => rand.nextDouble * R)),initializer实际为_ => rand.nextDouble * R,也就是0~1000double类型的数字。同理,Array.tabulate(N)(generatePoint),就是返回一个Array,Array中有N个元素,元素类型为generatePoint,generatePoint最终返回的是Vector的数据,每个Vector中包含一个ArrayArray中包含DDoule类型的数据。可以Ctr聚焦变量查看类型。

 

初始化聚类中心:

首先产生个数据,产生数据使用data(Int),因为前面有val data = generateData,所以调用的是generateData函数,传入任意一个整数,则和generatePoint(i: Int)对应,产生一个D维的数据。

 while (points.size < K) {

      //rand.nextInt(N)返回一个0的随机整数

      points.add(data(rand.nextInt(N)))

}

//pointsHashSet[Vector]类型),转换成Iterator类型(iterIterator[Vector]类型)

 val iter = points.iterator

//将聚类中心存入kPoints中,kPointsHashMap[Int, Vector]类型

    for (i <- 1 to points.size) {

      kPoints.put(i, iter.next())

    }

    println("Initial centers: " + kPoints)

 

 

 

closestPoint函数:

 

def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {

    var index = 0

    var bestIndex = 0

    var closest = Double.PositiveInfinity

    for (i <- 1 to centers.size) {

      val vCurr = centers.get(i).get

      val tempDist = p.squaredDist(vCurr)

      if (tempDist < closest) {

        closest = tempDist

        bestIndex = i

      }

}

P为传入的每个数据点,centers为聚类中心,及kPoints中的点,vCurr为获取到的kPoints中数据点,centers.get会得到一个Some对象,包含数据点的信息,再用get就可得到数据点的信息,即10维的坐标。利用squaredDist函数计算pvCurr的距离,最终找到和p点距离最近的聚类中心的位置(key)并返回。

通过下面的函数,

var closest = data.map (p => ( closestPoint(p, kPoints), (p, 1)))

得到和data中每个点距离最近的聚类中心的位置,并将p点的坐标记录下来,最终closet中存储的数据为(聚类中心位置(110的数字),(P点的坐标,1))【即(8,((751.2804067674601, 571.0403484148671, 580.0248845020607, 752.509948590651, 31.41823882658079, 357.91991947712864, 817.7969308356393, 417.68754675291876, 974.0356814958814, 713.4062578232291),1))

  

 var mappings = closest.groupBy[Int] (x=>x._1)

和同一个聚类中心最近的点进行归类,将他们分到同一类中

groupBy函数:

 // groupBy : groupBy[K](f: (A) ⇒ K): Map[K, List[A]] 将列表进行分组,分组的依据是应用f在元素上后产生的新元素

     //val data = List(("HomeWay","Male"),("XSDYM","Femail"),("Mr.Wang","Male"))

    //  val group1 = data.groupBy(_._2) // = Map("Male" -> List(("HomeWay","Male"),("Mr.Wang","Male")),"Female" -> List(("XSDYM","Femail")))

例如:有三个点和kPoints中的第8个点最近,将他们归到一类中

(8,((751.2804067674601, 571.0403484148671, 580.0248845020607, 752.509948590651, 31.41823882658079, 357.91991947712864, 817.7969308356393, 417.68754675291876, 974.0356814958814, 713.4062578232291),1))

(8,((469.02252061556857, 827.3224240149951, 151.03155452875828, 833.8662354441657, 460.30637266116116, 280.5719916728102, 195.964207423156, 179.27344087491736, 865.6867963273522, 486.59066182619404),1))

(8,((548.9937328293971, 328.5599644454902, 720.1246754012255, 387.79615803820235, 15.246621455438758, 102.8983670344621, 103.85689724130098, 173.2256480735601, 897.6338235309476, 418.85841914666554),1))

 

 

 var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))})

将对应的坐标和最后标记的1(这儿就是为了方便计数,统计有几个点属于同一个聚类中心)对应相加pointsStats类型为Map[Int,(Vector,Int)]。例如和上面数据相对应的结果为:

(8,((1769.2966602124256, 1726.9227368753523, 1451.1811144320445, 1974.1723420730189, 506.9712329431807, 741.390278184401, 1117.6180355000963, 770.1866357013963, 2737.356301354181, 1618.8553387960885),3))

 

var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}

对同一聚类中心的各个点求均值,得到新的聚类中心,newPoints类型为Map[Int,Vector].

tempDist为新的聚类中心和原来的聚类中心的距离,根据while循环的条件,距离小于convergeDist时退出循环,即认为聚类中心不再变化,每次求得的新的聚类中心,要重新加入到kPoints中,一直迭代,直到循环结束。

    tempDist = 0.0

      for (mapping <- newPoints) {

        tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)

      }

 

      for (newP <- newPoints) {

        kPoints.put(newP._1, newP._2)

      }

三、代码

package org.apache.spark.examplesimport java.util.Randomimport org.apache.spark.util.Vectorimport org.apache.spark.SparkContext._import scala.collection.mutable.HashMapimport scala.collection.mutable.HashSet/** * K-means clustering. */object LocalKMeans {  val N = 1000  val R = 1000    // Scaling factor  val D = 10  val K = 10  val convergeDist = 0.001  val rand = new Random(42)  def generateData = {    def generatePoint(i: Int) = {      Vector(D, _ => rand.nextDouble * R)    }    Array.tabulate(N)(generatePoint)  }  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {    var index = 0    var bestIndex = 0    var closest = Double.PositiveInfinity    for (i <- 1 to centers.size) {      val vCurr = centers.get(i).get      val tempDist = p.squaredDist(vCurr)      if (tempDist < closest) {        closest = tempDist        bestIndex = i      }    }    return bestIndex  }  def main(args: Array[String]) {    val data = generateData    var points = new HashSet[Vector]    var kPoints = new HashMap[Int, Vector]    var tempDist = 1.0    while (points.size < K) {      points.add(data(rand.nextInt(N)))    }    val iter = points.iterator    for (i <- 1 to points.size) {      kPoints.put(i, iter.next())    }    println("Initial centers: " + kPoints)    while(tempDist > convergeDist) {      var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))      var mappings = closest.groupBy[Int] (x => x._1)      var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))})      var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}      tempDist = 0.0      for (mapping <- newPoints) {        tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)      }      for (newP <- newPoints) {        kPoints.put(newP._1, newP._2)      }    }    println("Final centers: " + kPoints)  }}


0 0
原创粉丝点击