K-means 聚类分析实现类源码

来源:互联网 发布:holi天气数据来源 编辑:程序博客网 时间:2024/05/16 13:49

数据文件来自：http://archive.ics.uci.edu/ml/datasets/Wholesale+customers （注意：该数据集原始文件为逗号分隔的 CSV，而下面的代码按制表符切分字段，使用前需先将其转换为制表符分隔格式。）


import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

/**
 * Trains a K-means model with Spark MLlib on a training file, prints the cluster centers,
 * and then prints the cluster each record of a test file is assigned to.
 *
 * Both input files are read line by line; each data line holds tab-separated numeric
 * columns. Lines containing "Channel" (the column-name header) and blank lines are skipped.
 */
object KMeansClustering {

  /**
   * Program entry point.
   *
   * @param args trainingDataFilePath testDataFilePath numClusters numIterations runTimes;
   *             the program prints a usage message and exits with status 1 if fewer
   *             than five arguments are given.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 5) {
      println("Usage:KMeansClustering trainingDataFilePath testDataFilePath numClusters    numIterations runTimes")
      sys.exit(1)
    }
    val conf = new SparkConf().setAppName("Spark MLlib Exercise:K-Means Clustering")
    val sc = new SparkContext(conf)
    try {
      // Cached because K-means makes repeated passes over the training set.
      val parsedTrainingData = sc.textFile(args(0))
        .filter(isDataLine)
        .map(parseLine)
        .cache()

      // Cluster the data into numClusters groups using K-means.
      val numClusters = args(2).toInt
      val numIterations = args(3).toInt
      val runTimes = args(4).toInt
      // NOTE(review): the `runs` argument of KMeans.train is deprecated in newer Spark
      // releases and has no effect since Spark 2.0 — confirm against the Spark version in use.
      val clusters: KMeansModel =
        KMeans.train(parsedTrainingData, numClusters, numIterations, runTimes)

      println("Cluster Number:" + clusters.clusterCenters.length)
      println("Cluster Centers Information Overview:")
      clusters.clusterCenters.zipWithIndex.foreach { case (center, clusterIndex) =>
        println(s"Center Point of Cluster $clusterIndex:")
        println(center)
      }

      // Check which cluster each test record belongs to. The test file gets the same
      // header/blank-line filter as the training file; otherwise a header line would make
      // toDouble throw NumberFormatException.
      val parsedTestData = sc.textFile(args(1))
        .filter(isDataLine)
        .map(parseLine)
      parsedTestData.collect().foreach { testDataLine =>
        val predictedClusterIndex: Int = clusters.predict(testDataLine)
        println(s"The data $testDataLine belongs to cluster $predictedClusterIndex")
      }

      println("Spark MLlib K-means clustering test finished.")
    } finally {
      // Release the cluster resources even if training or prediction fails.
      sc.stop()
    }
  }

  /** True for lines that carry numeric data: not null, not blank, and not the header line. */
  private def isDataLine(line: String): Boolean =
    line != null && line.trim.nonEmpty && !isColumnNameLine(line)

  /**
   * Parses one tab-separated record into a dense MLlib vector.
   * Empty fields (e.g. from doubled tabs) are dropped; a non-numeric field throws
   * NumberFormatException.
   */
  private def parseLine(line: String) =
    Vectors.dense(line.split("\t").map(_.trim).filter(_.nonEmpty).map(_.toDouble))

  /** True when the line is the column-name header, recognized by the "Channel" column. */
  private def isColumnNameLine(line: String): Boolean =
    Option(line).exists(_.contains("Channel"))
}


0 0
原创粉丝点击