基于spark mllib的LDA模型训练Scala代码实现
来源:互联网 发布:多媒体数据挖掘 编辑:程序博客网 时间:2024/05/18 06:22
从事NLP算法工作也快一年了,主要时间花在了LDA上面,但是却一直没有好好整理一下,决心把到目前为止做的一些东西分享出来,如有疑问敬请指正。
在Github上建了一个自己的项目:CkoocNLP(去这个名字是想做一个NLP相关的技术的代码实现,不过目前上面还没有什么东西)。里面已经有基于spark的训练和预测代码实现,有兴趣的同学可以去看看,代码比较简单,可以直接checkout出来跑。
直接先上代码:
1. 入口代码
import algorithm.utils.LDAUtilsimport org.apache.log4j.{Level, Logger}import org.apache.spark.mllib.linalg.Vectorimport org.apache.spark.rdd.RDDimport org.apache.spark.{SparkConf, SparkContext}object LDATrainDemo { def main(args: Array[String]) { Logger.getRootLogger.setLevel(Level.WARN) val conf = new SparkConf().setAppName("LDATrain").setMaster("local[2]") val sc = new SparkContext(conf) //加载配置文件 val ldaUtils = LDAUtils("config/lda.properties") val args = Array("../ckooc-nlp/data/preprocess_result.txt", "models/ldaModel") val inFile = args(0) val outFile = args(1) //切分数据 val textRDD = ldaUtils.getText(sc, inFile, 36).filter(_.nonEmpty).map(_.split("\\|")).map(line => (line(0).toLong, line(1))) //训练模型 val (ldaModel, vocabulary, documents, tokens) = ldaUtils.train(sc, textRDD) //计算“文档-主题分布” val docTopics: RDD[(Long, Vector)] = ldaUtils.getDocTopics(ldaModel, documents) println("文档-主题分布:") docTopics.collect().foreach(doc => { println(doc._1 + ": " + doc._2) }) //计算“主题-词” val topicWords: Array[Array[(String, Double)]] = ldaUtils.getTopicWords(ldaModel, vocabulary.collect()) println("主题-词:") topicWords.zipWithIndex.foreach(topic => { println("Topic: " + topic._2) topic._1.foreach(word => { println(word._1 + "\t" + word._2) }) println() }) //保存模型和训练结果tokens ldaUtils.saveModel(sc, outFile, ldaModel, tokens) sc.stop() }}
主要对已经分词后的数据进行LDA训练,并保存模型。主要的处理步骤如下:
- 切分数据,获得包含主要数据内容的RDD
- 进行训练,获得LDA模型、词汇表、文本向量表示、所有切分tokens
- 获取“文档-主题分布”和“主题-词”结果,并打印输出
- 保存模型和切分tokens
说明:这里保存tokens是为了后面进行新文档预测时的文本向量表示使用的词汇表等数据与训练时保持一致。
下面分别对各个步骤进行说明
2. 模型训练
/** * LDA模型训练函数 * * @param sc SparkContext * @param rdd 输入数据 * @return (LDAModel, 词汇表) */ def train(sc: SparkContext, rdd: RDD[(Long, String)]): (LDAModel, RDD[String], RDD[(Long, Vector)], DataFrame) = { val k = config.k val maxIterations = config.maxIterations val vocabSize = config.vocabSize val algorithm = config.algorithm val alpha = config.alpha val beta = config.beta val checkpointDir = config.checkpointDir val checkpointInterval = config.checkpointInterval //将数据切分,转换为特征向量,生成词汇表,并计算数据总token数量 val featureStart = System.nanoTime() val tokens = splitLine(sc, rdd, vocabSize) val (documents, vocabulary, actualNumTokens) = featureToVector(tokens, tokens, vocabSize) val vocabRDD = sc.parallelize(vocabulary) val actualCorpusSize = documents.count() val actualVocabSize = vocabulary.length val featureElapsed = (System.nanoTime() - featureStart) / 1e9 featureInfo(actualCorpusSize, actualVocabSize, actualNumTokens, featureElapsed) val lda = new LDA() val optimizer = selectOptimizer(algorithm, actualCorpusSize) lda.setOptimizer(optimizer) .setK(k) .setMaxIterations(maxIterations) .setDocConcentration(alpha) .setTopicConcentration(beta) .setCheckpointInterval(checkpointInterval) if (checkpointDir.nonEmpty) { sc.setCheckpointDir(checkpointDir) } //训练LDA模型 val trainStart = System.nanoTime() val ldaModel = lda.run(documents) val trainElapsed = (System.nanoTime() - trainStart) / 1e9 trainInfo(documents, ldaModel, actualCorpusSize, trainElapsed) (ldaModel, vocabRDD, documents, tokens) }
- 切分tokens(splitLine)
- 文本向量表示(featureToVector)
- 设置LDA训练参数
- LDA模型训练(run)
3.1 文档-主题分布
3.2 主题-词
以上两个结果仅展示部分数据
0 0
- 基于spark mllib的LDA模型训练Scala代码实现
- 基于spark mllib的LDA模型训练源码解析
- Spark MLlib LDA主题模型
- Spark MLlib LDA 基于GraphX实现原理及源码分析
- scala---文档主题生成模型(LDA)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 基于Spark MLlib平台和基于模型的协同过滤算法的电影推荐系统(二)代码实现
- 文档主题生成模型(LDA)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 分享Spark MLlib训练的广告点击率预测模型
- LDA主题模型的java代码实现
- LDA的python实现之模型参数训练
- LDA的python实现之模型参数训练
- Spark MLlib LDA 源码解析
- Spark MLlib LDA 源码解析
- 使用Spark MLlib训练和提供自然语言处理模型
- 利用spark的mllib构建GBDT模型
- 干货:基于Spark Mllib的SparkNLP库。
- 基于Spark Mllib的文本分类
- 基于Spark Streaming和Spark MLlib实现文本情感分析
- windows下创建并使用静态链接库(.lib)
- zcat,zgrep用法
- PHP用redis实现多进程队列
- Spark的join与cogroup简单示例
- linux 常用命令二 网络
- 基于spark mllib的LDA模型训练Scala代码实现
- Android中AIDL的实现使用
- Android Studio安装后的一些必要设置
- C# 实验五--平面直角坐标系
- 【学习笔记----数据结构04-单循环链表】
- 开通C博客了
- 5.4用形态学滤波器检测边缘和角点
- mybatis入门到精通学习文章总结
- 常用时间处理方法:时间戳和格式化时间之间转换;时间比大小