SparkML之特征提取(一)主成分分析(PCA)

来源:互联网 发布:网络创业计划书 编辑:程序博客网 时间:2024/04/26 15:02

主成分分析(Principal Component Analysis, PCA)是一种多元统计分析方法：它通过线性变换把多个原始变量转换为少数几个重要的综合变量(主成分)。

--------------------------------------------目录--------------------------------------------------------

理论和数据见附录

Spark 源码(mllib包)

实验

----------------------------------------------------------------------------------------------------------

Spark 源码(mllib包)

/**
 * A feature transformer that projects vectors to a low-dimensional space using PCA.
 *
 * @param k number of principal components; must be positive
 */
@Since("1.4.0")
class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
  require(k > 0,
    s"Number of principal components must be positive but got ${k}")

  /**
   * Computes a [[PCAModel]] that contains the principal components of the input vectors.
   *
   * @param sources source vectors. Only the size of the first vector is validated here.
   * @return a model holding the top `k` principal components and their explained variance
   * @throws IllegalArgumentException if `k` exceeds the size of the source vectors
   */
  @Since("1.4.0")
  def fit(sources: RDD[Vector]): PCAModel = {
    // Evaluate first() only once: every call launches a separate Spark job.
    // NOTE(review): first() fails on an empty RDD before this require is reached.
    val numFeatures = sources.first().size
    require(k <= numFeatures,
      s"source vector size $numFeatures must be no less than k=$k")

    val mat = new RowMatrix(sources)
    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
    val densePC = pc match {
      case dm: DenseMatrix =>
        dm
      case sm: SparseMatrix =>
        /* Convert a sparse matrix to dense.
         *
         * RowMatrix.computePrincipalComponents always returns a dense matrix.
         * The following code is a safeguard.
         */
        sm.toDense
      case m =>
        throw new IllegalArgumentException("Unsupported matrix format. Expected " +
          s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
    }
    // The model stores a DenseVector; convert only when the result is sparse.
    val denseExplainedVariance = explainedVariance match {
      case dv: DenseVector =>
        dv
      case sv: SparseVector =>
        sv.toDense
    }
    new PCAModel(k, densePC, denseExplainedVariance)
  }

  /**
   * Java-friendly version of [[fit()]]
   */
  @Since("1.4.0")
  def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd)
}

/**
 * Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA.
 *
 * @param k number of principal components.
 * @param pc a principal components Matrix. Each column is one principal component.
 * @param explainedVariance explained variance of each principal component, as produced by
 *                          `RowMatrix.computePrincipalComponentsAndExplainedVariance`.
 */
@Since("1.4.0")
class PCAModel private[spark] (
    @Since("1.4.0") val k: Int,
    @Since("1.4.0") val pc: DenseMatrix,
    @Since("1.6.0") val explainedVariance: DenseVector) extends VectorTransformer {

  /**
   * Transform a vector by computed Principal Components.
   *
   * @param vector vector to be transformed.
   *               Vector must be the same length as the source vectors given to [[PCA.fit()]].
   * @return transformed vector. Vector will be of length k.
   * @throws IllegalArgumentException if `vector` does not have `pc.numRows` elements,
   *                                  or is neither a DenseVector nor a SparseVector
   */
  @Since("1.4.0")
  override def transform(vector: Vector): Vector = {
    // Both multiplications below require vector.size == pc.numRows. Checking it here
    // gives a message that names both sizes instead of a generic dimension error.
    require(vector.size == pc.numRows,
      s"Expected a vector of size ${pc.numRows} (the source vector size) " +
        s"but got ${vector.size}")
    vector match {
      case dv: DenseVector =>
        pc.transpose.multiply(dv)
      case SparseVector(size, indices, values) =>
        /* Wrap the sparse vector as a one-column CSC matrix and transpose it into a single
         * row, so that row * pc is a 1 x k matrix whose values are the projection. */
        val sm = Matrices.sparse(size, 1, Array(0, indices.length), indices, values).transpose
        val projection = sm.multiply(pc)
        Vectors.dense(projection.values)
      case _ =>
        throw new IllegalArgumentException("Unsupported vector format. Expected " +
          s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
    }
  }
}

---------------------------------------------------------------------------------------------------------

SparkML实验

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Runs MLlib PCA on a whitespace-separated numeric text file and prints, on the driver,
 * the projection of every row onto the top `k` principal components.
 *
 * Usage: myPCA [inputPath] [k]
 *   inputPath defaults to /root/application/upload/pca2.data
 *   k         defaults to 3
 */
object myPCA {

  private val DefaultInputPath = "/root/application/upload/pca2.data"
  private val DefaultK = 3

  def main(args: Array[String]): Unit = {
    // Lower Spark's log levels before the context starts, so startup logging is quiet too.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.Server").setLevel(Level.OFF)

    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val k = args.lift(1).map(_.toInt).getOrElse(DefaultK)

    val conf = new SparkConf().setAppName("PCA example").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val data = sc.textFile(inputPath)
      // Split on any run of whitespace and skip blank lines. split(' ') yields empty
      // tokens for repeated spaces, and "".toDouble throws NumberFormatException.
      val parseData = data
        .map(_.trim)
        .filter(_.nonEmpty)
        .map(line => Vectors.dense(line.split("\\s+").map(_.toDouble)))
        .cache() // read by both fit() and transform(); avoid re-reading the file

      val model = new PCA(k).fit(parseData)
      // collect() brings the (small) result to the driver; RDD.foreach(println)
      // would print on the executors instead.
      model.transform(parseData).collect().foreach(println)

      /* Sample output for pca2.data with k = 3:
        [-198.49935555431662,61.7455925014451,-33.61561582724634]
        [-142.6503762139188,42.83576581230462,-27.723300375043127]
        [-94.48444346449276,37.63137787042039,-18.467916687311757]
        [-93.78770648660057,53.13442729915277,-20.324679585348406]
        [-115.21309309209421,64.72629901491086,-24.068684431501]
        [-141.13717390563068,62.443549430022024,-32.15482042868974]
        [-139.84404002633448,85.49929177772042,-26.90430756804854]
        [-106.34627395862736,57.60589638943985,-23.47345414370614]
        [-254.30867520979697,40.87956572432333,-12.424267061380176]
        [-146.56200808994245,52.842236008590454,-16.703674457958243]
        [-170.42181527333886,63.27229377718993,-21.440842300235158]
        [-139.13974251002367,74.9052975468746,-12.130842693355147]
        [-131.03062483262897,72.29955746812841,-15.20705763790804]
        [-126.21628609915788,71.19600990352119,-11.411808043562743]
        [-120.23904415710874,39.83322827884836,-26.220672650471542]
        [-97.36990893617941,43.377395313806836,-17.568739657112463]
      */
      println("---------------------------------------------------")
    } finally {
      // Always release the SparkContext, even if reading, parsing or fitting fails.
      sc.stop()
    }
  }
}

附录

链接:http://pan.baidu.com/s/1dELByj3 密码:wsnb


0 0
原创粉丝点击