spark PIPELINE使用

来源:互联网 发布:艾默生dcs组态软件 编辑:程序博客网 时间:2024/04/28 14:57

ML中的pipeline估计是参考了py的Scipy等把

1.PIPELINE的主要部分就是

val pipeline = new Pipeline()  .setStages(Array(tokenizer, hashingTF, lr))// Fit the pipeline to training documents.val model = pipeline.fit(training)

 

2.将各个计算阶段按照stages顺序,整个阶段就是依靠DF的col,设置input,output

(1).构造tokenizer阶段

val training = sqlContext.createDataFrame(Seq(  (0L, "a b c d e spark", 1.0),  (1L, "b d", 0.0),  (2L, "spark f g h", 1.0),  (3L, "hadoop mapreduce", 0.0))).toDF("id", "text", "label")// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.val tokenizer = new Tokenizer()  .setInputCol("text")  .setOutputCol("words")

 

(2).TF阶段

val hashingTF = new HashingTF()  .setNumFeatures(1000)  .setInputCol(tokenizer.getOutputCol)  .setOutputCol("features")

 

(3).lr阶段

val lr = new LogisticRegression()  .setMaxIter(10)  .setRegParam(0.01)

 

3.我们看看pipeline.fit做了什么事情,就是如何将各个阶段连接起来的

(1).将各各阶段的不同类分开,这里先找出评估模型,就是LogisticRegression(LogisticRegression是继承Estimator)

theStages.view.zipWithIndex.foreach { case (stage, index) =>  stage match {    case _: Estimator[_] =>      indexOfLastEstimator = index    case _ =>  }}

(2).Estimator类型的执行fit,Transformer类型的执性transformer

theStages.view.zipWithIndex.foreach { case (stage, index) =>  if (index <= indexOfLastEstimator) {    val transformer = stage match {      case estimator: Estimator[_] =>        estimator.fit(curDataset)      case t: Transformer =>        t      case _ =>        throw new IllegalArgumentException(          s"Do not support stage $stage of type ${stage.getClass}")    }    if (index < indexOfLastEstimator) {      curDataset = transformer.transform(curDataset)    }    transformers += transformer  } else {    transformers += stage.asInstanceOf[Transformer]  }}

(3).最后构造出PipelineModel

new PipelineModel(uid, transformers.toArray).setParent(this)

 

 

 

0 0
原创粉丝点击