scala实现Kmeans算法

来源:互联网 发布:约爱cms源码程序.zip 编辑:程序博客网 时间:2024/06/05 08:12

  好久没有写博客了,虽然并没有多少人看。kmeans的思想大家自己去查找,我就不一一叙述了。kmeans之所以不能达到全局最优,是因为他的cost函数是一个非凸的函数,找不到最低点那个位置。kmeans的初始位置很重要,本片博客采取的就是最基本的随机生成初始中心点(我很好奇,有些人的代码就是随机生成n和点,都不带判重的),比较 好的生成算法是kmeans++,保证初始点间的距离最远。这是我初学scala一个月写的代码,还没有体会到scala的精髓,望各位指导!

import scala.collection.immutable.Vectorimport scala.io.Sourceimport scala.util.Randomimport scala.collection.mutable.ArrayBufferimport com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparatorobject MyKmeans {   //读取需要聚类的数据   def GetData(pathfile:String):Array[Point]={    val source=Source.fromFile(pathfile)    val lines=source.getLines().toArray    val data=lines.map { x =>new Point( x.split(" ").map { y => y.toDouble})}    println("读取数据完成")    data  }  def main(args: Array[String]): Unit = {    var data=GetData("/home/hadoop/xixi")    var k=new Kmeans(data,5,20)    k.run()    k.SaveData  }}

import scala.collection.immutable.Vectorimport scala.io.Sourceimport scala.util.Randomimport scala.collection.mutable.ArrayBufferimport com.sun.jersey.core.spi.factory.MessageBodyFactory.DistanceComparatorimport org.apache.spark.mllib.util.Saveableimport java.io.PrintWriterclass Kmeans(val data:Array[Point],val numClusters:Int,val MaxIterations:Int,    val threshold:Double=1e-4,val savepath:String="/home/hadoop/haha") {  //中心点的坐标  var CenterPoint=new Array[Point](numClusters)  //每个点对应的中心点相关信息  var Costinform=new Array[Vedist](data.length)  //构造出一个长度为len的Point数组,Point的各个量为0 ,Point为k维度  def InitArrPoint(len:Int,k:Int):Array[Point]={    var arr=new Array[Double](k)    var arrp=new Array[Point](len)    arrp.map { x => new Point(arr)}  }  //输出该数据结构中的数据,便于调试使用  def Output(data:Array[Point])  {    data.foreach { x => x.OutPut}  }  //获取初始的中心点  def InitCenterPoint(){    var ve=new ArrayBuffer[Double]    val st=System.nanoTime()    var n=0    while(ve.length<numClusters)    {      val a=(new Random()).nextInt(data.length)      if(!ve.exists {x=>x==a})      {        ve+=a        CenterPoint(n)=data(a)        n+=1      }    }    val ed=System.nanoTime()    println("--------------------------------------\n随机中心点已经生成成功,生成时间为:"+(ed-st)+"\n随机点为:")    Output(CenterPoint)  }  //找到一个点距离最近的中心  def FastSearch(point:Point,n:Int):Vedist=  {    var cost=Double.MaxValue    var k = -1    for(i<-(0 until CenterPoint.length))    {      var m=point.Distance(CenterPoint(i))      if(cost>m)      {        cost=m        k=i      }    }    val m=Vedist(k,cost)    Costinform(n)=m    m  }  //设置中心点坐标  def setCenterPoint(NewPoint:Array[Point])  {    for(i<- 0 until numClusters)      CenterPoint(i)=NewPoint(i)  }  //计算损失函数  def ComputeCost:Double=  {    var sum=0.0     Costinform.foreach { x =>sum+=x.cost}    sum  }  //kmeans函数运行主体  def run()  {    InitCenterPoint()    var k=0    var f=true    val st=System.nanoTime()    while(k<MaxIterations&&f)    {      k+=1      var NewPoint=new Array[Point](numClusters)      var SumPoint=InitArrPoint(numClusters,data(0).px.length)       //计算每个点属于哪个中心点所在的类,并且记录每个类中点的数量,与该类中所有向量的和      for(i<- 0 until data.length)      {        var cid=FastSearch(data(i), i).center_id        SumPoint(cid)+=data(i)      }      //Output(SumPoint)      //新的中心点      for(i<-0 until numClusters)        NewPoint(i)=SumPoint(i)./(Costinform.count {_.center_id==i})      //计算新的中心点与原中心点之间的聚类是否小于阈值      f=NewPoint.zip(CenterPoint).map(f=>f._1.Distance(f._2)).exists {_>threshold}      //如果符合条件则继续更新计算      if(!f)      {        for(i<-0 until numClusters)        CenterPoint(i)=NewPoint(i)        println("第"+k+"次中心点")        Output(CenterPoint)        println("第"+k+"次花费")        println(ComputeCost)      }    }    val ed=System.nanoTime()    println("Kmeans聚类时间为:"+(ed-st))  }  //保存数据  def SaveData{    val out=new PrintWriter(savepath)    out.println("中心点为:")    for(i<- 0 until CenterPoint.length)      out.println(CenterPoint(i).mkString)    out.println("花费为:")    out.println(ComputeCost)    out.println("各个点属于")    for(i<- 0 until data.length)      out.println(data(i).mkString+"          "+Costinform(i).center_id)    out.close()  }}

import scala.collection.mutable.ArrayBuffer import parquet.org.codehaus.jackson.map.ser.impl.PropertySerializerMap.Empty//定义点类  class Point(val px:Vector[Double]){    def this(p:Array[Double])    {      this(p.toVector)    }    def OutPut    {      px.foreach {x=>print(x+" ")}      println    }    def ^ :Double={      px.map { x => x*x }.sum    }    def +(that:Point):Point={      var m=new ArrayBuffer[Double]      for(i<-0 until px.length)        m+=(px(i)+that.px(i))      new Point(m.toArray)    }    def *(that:Point):Double={      var m=0.0      for(i<-0 until px.length)        m=m+px(i)*that.px(i)      m    }    def /(n:Int):Point={      var ve=new Array[Double](px.length)      for(i<-0 until px.length)        ve(i)=px(i)/n      new Point(ve)    }    def Distance(that:Point):Double=(this^ )+(that^ )-2*(that*this)    def init(len:Int):Point={      new Point(new Array[Double](len))    }    def mkString:String={      var str=""      px.foreach { x =>str+=x.toString()+" " }      str    }  }//单纯的储存信息的case类,center_id代表数据点对应的中心点,cost代表两点的花费case class Vedist(val center_id:Int,val cost:Double)
scala写的非常不地道,没有发挥函数式编程的优越性

0 0
原创粉丝点击