scala实现Kmeans算法
来源:互联网 发布:约爱cms源码程序.zip 编辑:程序博客网 时间:2024/06/05 08:12
好久没有写博客了,虽然并没有多少人看。kmeans的思想大家自己去查找,我就不一一叙述了。kmeans之所以不能达到全局最优,是因为他的cost函数是一个非凸的函数,找不到最低点那个位置。kmeans的初始位置很重要,本片博客采取的就是最基本的随机生成初始中心点(我很好奇,有些人的代码就是随机生成n和点,都不带判重的),比较 好的生成算法是kmeans++,保证初始点间的距离最远。这是我初学scala一个月写的代码,还没有体会到scala的精髓,望各位指导!
import scala.collection.immutable.Vectorimport scala.io.Sourceimport scala.util.Randomimport scala.collection.mutable.ArrayBufferimport com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparatorobject MyKmeans { //读取需要聚类的数据 def GetData(pathfile:String):Array[Point]={ val source=Source.fromFile(pathfile) val lines=source.getLines().toArray val data=lines.map { x =>new Point( x.split(" ").map { y => y.toDouble})} println("读取数据完成") data } def main(args: Array[String]): Unit = { var data=GetData("/home/hadoop/xixi") var k=new Kmeans(data,5,20) k.run() k.SaveData }}
import scala.collection.immutable.Vectorimport scala.io.Sourceimport scala.util.Randomimport scala.collection.mutable.ArrayBufferimport com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparatorimport org.apache.spark.mllib.util.Saveableimport java.io.PrintWriterclass Kmeans(val data:Array[Point],val numClusters:Int,val MaxIterations:Int, val threshold:Double=1e-4,val savepath:String="/home/hadoop/haha") { //中心点的坐标 var CenterPoint=new Array[Point](numClusters) //每个点对应的中心点相关信息 var Costinform=new Array[Vedist](data.length) //构造出一个长度为len的Point数组,Point的各个量为0 ,Point为k维度 def InitArrPoint(len:Int,k:Int):Array[Point]={ var arr=new Array[Double](k) var arrp=new Array[Point](len) arrp.map { x => new Point(arr)} } //输出该数据结构中的数据,便于调试使用 def Output(data:Array[Point]) { data.foreach { x => x.OutPut} } //获取初始的中心点 def InitCenterPoint(){ var ve=new ArrayBuffer[Double] val st=System.nanoTime() var n=0 while(ve.length<numClusters) { val a=(new Random()).nextInt(data.length) if(!ve.exists {x=>x==a}) { ve+=a CenterPoint(n)=data(a) n+=1 } } val ed=System.nanoTime() println("--------------------------------------\n随机中心点已经生成成功,生成时间为:"+(ed-st)+"\n随机点为:") Output(CenterPoint) } //找到一个点距离最近的中心 def FastSearch(point:Point,n:Int):Vedist= { var cost=Double.MaxValue var k = -1 for(i<-(0 until CenterPoint.length)) { var m=point.Distance(CenterPoint(i)) if(cost>m) { cost=m k=i } } val m=Vedist(k,cost) Costinform(n)=m m } //设置中心点坐标 def setCenterPoint(NewPoint:Array[Point]) { for(i<- 0 until numClusters) CenterPoint(i)=NewPoint(i) } //计算损失函数 def ComputeCost:Double= { var sum=0.0 Costinform.foreach { x =>sum+=x.cost} sum } //kmeans函数运行主体 def run() { InitCenterPoint() var k=0 var f=true val st=System.nanoTime() while(k<MaxIterations&&f) { k+=1 var NewPoint=new Array[Point](numClusters) var SumPoint=InitArrPoint(numClusters,data(0).px.length) //计算每个点属于哪个中心点所在的类,并且记录每个类中点的数量,与该类中所有向量的和 for(i<- 0 until data.length) { var cid=FastSearch(data(i), i).center_id SumPoint(cid)+=data(i) } //Output(SumPoint) //新的中心点 for(i<-0 until numClusters) NewPoint(i)=SumPoint(i)./(Costinform.count {_.center_id==i}) //计算新的中心点与原中心点之间的聚类是否小于阈值 f=NewPoint.zip(CenterPoint).map(f=>f._1.Distance(f._2)).exists {_>threshold} //如果符合条件则继续更新计算 if(!f) { for(i<-0 until numClusters) CenterPoint(i)=NewPoint(i) println("第"+k+"次中心点") Output(CenterPoint) println("第"+k+"次花费") println(ComputeCost) } } val ed=System.nanoTime() println("Kmeans聚类时间为:"+(ed-st)) } //保存数据 def SaveData{ val out=new PrintWriter(savepath) out.println("中心点为:") for(i<- 0 until CenterPoint.length) out.println(CenterPoint(i).mkString) out.println("花费为:") out.println(ComputeCost) out.println("各个点属于") for(i<- 0 until data.length) out.println(data(i).mkString+" "+Costinform(i).center_id) out.close() }}
import scala.collection.mutable.ArrayBuffer import parquet.org.codehaus.jackson.map.ser.impl.PropertySerializerMap.Empty//定义点类 class Point(val px:Vector[Double]){ def this(p:Array[Double]) { this(p.toVector) } def OutPut { px.foreach {x=>print(x+" ")} println } def ^ :Double={ px.map { x => x*x }.sum } def +(that:Point):Point={ var m=new ArrayBuffer[Double] for(i<-0 until px.length) m+=(px(i)+that.px(i)) new Point(m.toArray) } def *(that:Point):Double={ var m=0.0 for(i<-0 until px.length) m=m+px(i)*that.px(i) m } def /(n:Int):Point={ var ve=new Array[Double](px.length) for(i<-0 until px.length) ve(i)=px(i)/n new Point(ve) } def Distance(that:Point):Double=(this^ )+(that^ )-2*(that*this) def init(len:Int):Point={ new Point(new Array[Double](len)) } def mkString:String={ var str="" px.foreach { x =>str+=x.toString()+" " } str } }//单纯的储存信息的case类,center_id代表数据点对应的中心点,cost代表两点的花费case class Vedist(val center_id:Int,val cost:Double)scala写的非常不地道,没有发挥函数式编程的优越性
0 0
- scala实现Kmeans算法
- Spark:Scala实现KMeans算法
- Scala语言实现Kmeans聚类算法
- C++实现KMeans算法
- kmeans算法java实现
- WPF实现KMEANS算法
- matlab实现kmeans算法
- Java实现Kmeans算法
- kMeans算法JAVA实现
- Kmeans算法实现
- Hadoop 实现kmeans 算法
- kmeans算法及其实现
- Kmeans算法java实现
- KMeans算法的实现
- python实现kmeans算法
- Kmeans算法及实现
- Kmeans算法的实现二
- kmeans算法的java实现
- shell的ps命令参数列表解释说明
- git am, git apply, git format-patch,git diff 用法
- 缓存淘汰算法--LRU算法
- libcurl的编译及使用
- Android 6.0 以上版本提示“检测到屏幕叠加层”的问题,规避方法
- scala实现Kmeans算法
- 不在以下合法域名列表中,请参考文档:https://mp.weixin.qq.com/debug/wxadoc/dev/api/network-request.html
- ios UITableView单元格多选框的实现
- 指针习题1
- 欢迎使用CSDN-markdown编辑器
- githup上Android APP好用的文本工具
- 高速公路ETC卡签之我见7-用户卡发行
- 点击下拉框其他地方下拉框收起
- java中判断文件是否为空内容