TD-IDF在spark中的使用(ml方式)

来源:互联网 发布:华为公司社会关系网络 编辑:程序博客网 时间:2024/06/05 23:00

上一篇 文章提到了TD-IDF的原理和大致使用方式, 现在我写了一个比较完整的例子来展示一下, 该例子包含了数据导入(为了统一, 将文件导入了数据库),处理, 以及结果导出功能.

import org.apache.spark.mllib.linalg.Vectorimport com.zte.bigdata.vmax.machinelearning.common.{LogSupport, CreateSparkContext}import org.apache.spark.ml.feature.{IDF, HashingTF, Tokenizer}import org.apache.spark.sql.DataFrameimport org.apache.spark.sql.functions.{col, udf}import scala.collection.mutable.ArrayBufferimport org.apache.spark.mllib.feature.{HashingTF => MllibHashingTF}import scala.util.Try/** table_tf_idf表是输入表, 其结构为theme+content, content表示主题theme下的某一篇文章,* 下图显示了一个主题, 其实有很多.* * +--------------+--------------------+* |         theme|             content|* +--------------+--------------------+* |comp.windows.x|From: chongo@toad...|* |comp.windows.x|From: chongo@toad...|* |comp.windows.x|From: steve@ecf.t...|* |comp.windows.x|From: ware@cis.oh...|* |comp.windows.x|From: stevevr@tt7...|* ...*/class TFIDFModel extends CreateSparkContext {// CreateSparkContext中包含sc, hc(sqlContext)的创建  // 保证RDD可以转换为DataFrame  import hc.implicits._  def calTFIDF(topN:Int): DataFrame = {    hc.sql("use database")    // 因为要提取theme的关键词, 所以需要先做聚合,将相同主题文章放到一起    // select concat_ws(" ".collect_list(content) group by theme) from table_tf_idf 应该也可以    val dataGroupByTheme = hc.sql(s"select theme,content from table_tf_idf").rdd    .map(row => (row.getString(0), row.getString(1).replaceAll("\\p{Punct}", " ")))// 去掉英文标点    .reduceByKey((x, y) => x + y) //聚合    .toDF("theme", "content") //转回DF    // 用于做分词, 结果可见前文    val tokenizer = new Tokenizer().setInputCol("content").setOutputCol("words")    val wordsData = tokenizer.transform(dataGroupByTheme)    // 获取单词->Hasing的映射(单词 -> 哈希值)    //此处HashingTF属于mllib, 默认numFeatures为1<<20, 但是ml下的hashingTF却是1<<18, 要统一才能确保hash结果一致    val mllibHashingTF = new MllibHashingTF(1 << 18)     val mapWords = wordsData.select("words").rdd.map(row => row.getAs[ArrayBuffer[String]](0))    .flatMap(x => x).map(w => (mllibHashingTF.indexOf(w), w)).collect.toMap    // 计算出TF值    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")    val featurizedData = hashingTF.transform(wordsData)    // 计算IDF值, 实际计算出来的形式为稀疏矩阵 [标签总数,[标签1,标签2,...],[标签1的值,标签2的值,...]]    val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")    val idfModel = idf.fit(featurizedData)    val rescaledData = idfModel.transform(featurizedData)    // 将得到的数据按照tf-idf值从大到小排序,提取topN,并且将hashing id 转为单词    val takeTopN = udf { (v: Vector) =>       (v.toSparse.indices zip v.toSparse.values)      .sortBy(-_._2) //负值就是从大到小      .take(topN)      .map(x => mapWords.getOrElse(x._1, "null") + ":" + f"${x._2}%.3f".toString) // 冒号分隔单词和值,值取小数后三位      .mkString(";") } // 词语和值的对以;隔开(别用逗号,会与hive表格式的TERMINATED BY ','冲突)    rescaledData.select(col("theme"), takeTopN(col("features")).as("features"))    //    rescaledData.select("features", "theme").take(3).foreach(println)  }  // 将原始文件写入数据库  def data2DB() = {    case class Article(theme:String, content:String) // 每个hdfs路径主题下都有很多文章    val doc1 = sc.wholeTextFiles("/sdl/data/20news-bydate-train/comp.windows.x")    val doc2 = sc.wholeTextFiles("/sdl/data/20news-bydate-train/comp.graphics")    val doc3 = sc.wholeTextFiles("/sdl/data/20news-bydate-train/misc.forsale")    val doc = doc1.union(doc2).union(doc3)  // 创建df, 并且把文章的换行和回车去掉    val df =  sqlContext.createDataFrame(doc.map(x => Article(x._1.split("/").init.last,x._2.replaceAll("\n|\r",""))))    df.registerTempTable("df")    // df入库的方式与 saveOutPut 函数相同,省略  }  // 将dataframe格式数据保存到数据库  def saveOutPut(output:String) = {    log.debug(s"TF-IDF result writed to the table $output ...")    calTFIDF().registerTempTable("tfidf_result_table")    hc.sql("use database")    hc.sql( s"""drop table if exists $output""")    hc.sql(      s"""create table IF NOT EXISTS $output (      theme  String,      tfidf_result   String      )      ROW FORMAT DELIMITED FIELDS TERMINATED BY ','""")    val sql = s"insert overwrite table $output select theme,features from tfidf_result_table"    log.debug("load data to table sql: " + sql)    try {      hc.sql(sql)    } catch {      case e: Exception => log.error("load data to table error")        throw e    }  }}

保存到表里的数据如下, 取top3

这里写图片描述

0 0
原创粉丝点击