Spark 中LocalKmeans算法详解
来源:互联网 发布:柠檬网络电视 柠檬tv 编辑:程序博客网 时间:2024/06/05 14:15
一、Kmeans算法思想
Kmeans算法的具体思想这里省略。
Kmeans算法实现步骤一般如下:
1、从D中随机取k个元素,作为k个簇的各自的中心,或者随机生成k个中心元素。
2、分别计算剩下的元素到k个簇中心的相异度,将这些元素分别划归到相异度最低的簇。
3、根据聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所有元素各自维度的算术平均数。
4、将D中全部元素按照新的中心重新聚类。
5、重复第4步,直到聚类结果不再变化。
6、将结果输出。
二、代码分析
关于generateData函数的解析:
这个函数的作用就是产生N个D维的数据,程序中的变量已经定义好了
N = 1000,即1000个点
R = 1000,规约范围
D = 10,数据的维度,即10维的数据
def generateData = {
def generatePoint(i: Int) = {
Vector(D, _ => rand.nextDouble * R)
}
Array.tabulate(N)(generatePoint)
}
首先看几个函数:
rand函数:这个是随机产生Double类型的数据,rand.NextDouble产生的是0~1的double类型的数据
Vector函数:此处的Vector是org.apache.spark.util.Vector类,代码中用到的是Vector的的apply方法,如下:
def apply(length: Int, initializer: Int => Double): Vector = {
val elements: Array[Double] = Array.tabulate(length)(initializer)
return new Vector(elements)
}
这个apply方法,返回结果是一个Vector,Vector中的元素是一个Array,Array中有length个double类型的元素,主要是依靠tabulate函数产生的。
tabulate函数:
14
def tabulate[T]( n: Int )(f: (Int)=> T): Array[T]
返回包含一个给定的函数的值超过从0开始的范围内的整数值的数组。
15
def tabulate[T]( n1: Int, n2: Int )( f: (Int, Int ) => T): Array[Array[T]]
返回一个包含给定函数的值超过整数值从0开始范围的二维数组。
根据定义,就是产生一个数组,数组中元素的个数为n,元素的范围为0到给定的数值范围,返回的类型为Array.
这样,Array.tabulate(length)(initializer)就是返回一个Array,元素个数为length,元素的范围为initializer。当我们调用时(Vector(D, _ => rand.nextDouble * R)),initializer实际为_ => rand.nextDouble * R,也就是0~1000的double类型的数字。同理,Array.tabulate(N)(generatePoint),就是返回一个Array,Array中有N个元素,元素类型为generatePoint,而generatePoint最终返回的是Vector的数据,每个Vector中包含一个Array,Array中包含D个Doule类型的数据。可以Ctr聚焦变量查看类型。
初始化聚类中心:
首先产生K 个数据,产生数据使用data(Int),因为前面有val data = generateData,所以调用的是generateData函数,传入任意一个整数,则和generatePoint(i: Int)对应,产生一个D维的数据。
while (points.size < K) {
//rand.nextInt(N)返回一个0到N 的随机整数
points.add(data(rand.nextInt(N)))
}
//将points(HashSet[Vector]类型),转换成Iterator类型(iter为Iterator[Vector]类型)
val iter = points.iterator
//将聚类中心存入kPoints中,kPoints为HashMap[Int, Vector]类型
for (i <- 1 to points.size) {
kPoints.put(i, iter.next())
}
println("Initial centers: " + kPoints)
closestPoint函数:
def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
var index = 0
var bestIndex = 0
var closest = Double.PositiveInfinity
for (i <- 1 to centers.size) {
val vCurr = centers.get(i).get
val tempDist = p.squaredDist(vCurr)
if (tempDist < closest) {
closest = tempDist
bestIndex = i
}
}
P为传入的每个数据点,centers为聚类中心,及kPoints中的点,vCurr为获取到的kPoints中数据点,centers.get会得到一个Some对象,包含数据点的信息,再用get就可得到数据点的信息,即10维的坐标。利用squaredDist函数计算p和vCurr的距离,最终找到和p点距离最近的聚类中心的位置(key)并返回。
通过下面的函数,
var closest = data.map (p => ( closestPoint(p, kPoints), (p, 1)))
得到和data中每个点距离最近的聚类中心的位置,并将p点的坐标记录下来,最终closet中存储的数据为(聚类中心位置(1到10的数字),(P点的坐标,1))【即(8,((751.2804067674601, 571.0403484148671, 580.0248845020607, 752.509948590651, 31.41823882658079, 357.91991947712864, 817.7969308356393, 417.68754675291876, 974.0356814958814, 713.4062578232291),1))】
var mappings = closest.groupBy[Int] (x=>x._1)
和同一个聚类中心最近的点进行归类,将他们分到同一类中
groupBy函数:
// groupBy : groupBy[K](f: (A) ⇒ K): Map[K, List[A]] 将列表进行分组,分组的依据是应用f在元素上后产生的新元素
//val data = List(("HomeWay","Male"),("XSDYM","Femail"),("Mr.Wang","Male"))
// val group1 = data.groupBy(_._2) // = Map("Male" -> List(("HomeWay","Male"),("Mr.Wang","Male")),"Female" -> List(("XSDYM","Femail")))
例如:有三个点和kPoints中的第8个点最近,将他们归到一类中
(8,((751.2804067674601, 571.0403484148671, 580.0248845020607, 752.509948590651, 31.41823882658079, 357.91991947712864, 817.7969308356393, 417.68754675291876, 974.0356814958814, 713.4062578232291),1))
(8,((469.02252061556857, 827.3224240149951, 151.03155452875828, 833.8662354441657, 460.30637266116116, 280.5719916728102, 195.964207423156, 179.27344087491736, 865.6867963273522, 486.59066182619404),1))
(8,((548.9937328293971, 328.5599644454902, 720.1246754012255, 387.79615803820235, 15.246621455438758, 102.8983670344621, 103.85689724130098, 173.2256480735601, 897.6338235309476, 418.85841914666554),1))
var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))})
将对应的坐标和最后标记的1(这儿就是为了方便计数,统计有几个点属于同一个聚类中心)对应相加pointsStats类型为Map[Int,(Vector,Int)]。例如和上面数据相对应的结果为:
(8,((1769.2966602124256, 1726.9227368753523, 1451.1811144320445, 1974.1723420730189, 506.9712329431807, 741.390278184401, 1117.6180355000963, 770.1866357013963, 2737.356301354181, 1618.8553387960885),3))
var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}
对同一聚类中心的各个点求均值,得到新的聚类中心,newPoints类型为Map[Int,Vector].
tempDist为新的聚类中心和原来的聚类中心的距离,根据while循环的条件,距离小于convergeDist时退出循环,即认为聚类中心不再变化,每次求得的新的聚类中心,要重新加入到kPoints中,一直迭代,直到循环结束。
tempDist = 0.0
for (mapping <- newPoints) {
tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
}
for (newP <- newPoints) {
kPoints.put(newP._1, newP._2)
}
三、代码
package org.apache.spark.examplesimport java.util.Randomimport org.apache.spark.util.Vectorimport org.apache.spark.SparkContext._import scala.collection.mutable.HashMapimport scala.collection.mutable.HashSet/** * K-means clustering. */object LocalKMeans { val N = 1000 val R = 1000 // Scaling factor val D = 10 val K = 10 val convergeDist = 0.001 val rand = new Random(42) def generateData = { def generatePoint(i: Int) = { Vector(D, _ => rand.nextDouble * R) } Array.tabulate(N)(generatePoint) } def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { var index = 0 var bestIndex = 0 var closest = Double.PositiveInfinity for (i <- 1 to centers.size) { val vCurr = centers.get(i).get val tempDist = p.squaredDist(vCurr) if (tempDist < closest) { closest = tempDist bestIndex = i } } return bestIndex } def main(args: Array[String]) { val data = generateData var points = new HashSet[Vector] var kPoints = new HashMap[Int, Vector] var tempDist = 1.0 while (points.size < K) { points.add(data(rand.nextInt(N))) } val iter = points.iterator for (i <- 1 to points.size) { kPoints.put(i, iter.next()) } println("Initial centers: " + kPoints) while(tempDist > convergeDist) { var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) var mappings = closest.groupBy[Int] (x => x._1) var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))}) var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} tempDist = 0.0 for (mapping <- newPoints) { tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) } for (newP <- newPoints) { kPoints.put(newP._1, newP._2) } } println("Final centers: " + kPoints) }}
- Spark 中LocalKmeans算法详解
- LocalKMeans
- spark中算子详解:aggregateByKey
- spark中算子详解:combineByKey
- kmeans算法详解与spark实战
- spark mllib中ALS算法思想
- Spark MLlib 中Isotonic regression算法简介
- spark中协同过滤算法分析
- hadoop常用算法在spark中实现
- Spark中executor-memory参数详解
- spark算法
- Spark详解
- spark详解
- spark详解
- spark详解
- Spark MLlib 中power iteration clustering (PIC)算法简介
- 如何解释spark mllib中ALS算法的原理?
- spark的Graphx中subGraph算法的改进
- Android+xml;
- 机器学习
- 内存数据库Redis小Demo 包括持久性测试
- java、js处理科学计数法的问题
- kernel文件动态调试功能 -- dynamic_debug 打开及半闭
- Spark 中LocalKmeans算法详解
- 基于设备树的GPIO驱动(通过系统节点控制)
- Activity启动模式
- 来北京的第一场雪
- 【Leetcode】之Remove Nth Node From End of List
- SQL UNION 和 UNION ALL 操作符
- Android Studio系列教程(一)一--下载和安装
- 修改ndk编译时的线程数
- sql语句备份