Spark MLlib — EMLDA
来源:互联网 发布:mac win 共享文件夹 编辑:程序博客网 时间:2024/06/06 04:37
LDA(Latent Dirichlet allocation)是一种主题模型,它可以将文档集中每篇文档的主题按照概率分布的形式给出,也即根据给定的一篇文档,推测其主题分布。同时它是一种无监督学习算法,在训练时不需要手工标注的训练集,需要的仅仅是文档集以及指定主题的数量k即可。此外LDA的另一个优点则是,对于每一个主题均可找出一些词语来描述它。本文主要介绍LDA涉及的数学知识以及Spark MLlib中基于Graphx的实现方式。
1.理论基础
这部分内容主要参考《LDA数学八卦》和《通俗理解LDA主题模型》,涉及到数学内容真是挺多的,网上大多数博文的阐述过程大多一上来就摆出一大堆理论基础,自己学习的过程中也会觉得一开始就很吃力,而且不能把众多理论和最终的模型联系起来,所以一直想找到一种易于理解的方式来阐述该模型。
1.1 LDA引入— Dirichlet先验分布
在NLP领域,文本的表现形式通常是有序的词序列,即
1.1.1 主题无关模型
- Unigram Model:投掷一个V面的骰子,重复N次生成词序列,则这N个随机变量(即生成的N个单词)是独立同分布的,该分布的概率函数即为骰子每一面出现的概率为
p⃗ ={p1,p2,...pV} 。则词典中每个单词出现次数(记为n⃗ =n1,n2,...,nV )的联合分布满足Multinomial分布(多项分布),其概率密度函数为:p(n⃗ )=Mult(n⃗ |p⃗ ,N)=(Nn⃗ )∏k=1Vpnkk=N!n1!n2!...nV!pn11pn22...pnVV(1.1) - 贝叶斯Unigram Model:在Unigram Model的基础上为
p⃗ 加入了先验分布,即需要先以一定的概率选出骰子。这里的先验分布选择了Multinomial式分布的共轭先验分布Dirichlet分布,Dirichlet分布的一般表现形式如下:Dir(p⃗ |α⃗ )=Γ(∑Kk=1αk)∏Kk=1Γ(αk)∏k=1Kpnkk(1.2)
其中Γ(x)=∫∞0tx−1e−tdt=(n−1)! 为gamma函数。
共轭先验分布的本质为:Dirichlet先验+Multinomial分布—>后验分布也为Dirichlet分布 即:Dir(p⃗ |α⃗ )+Mult(m⃗ )=Dir(p⃗ |α⃗ +m⃗ )(1.3)
关于gamma函数的更多性质以及Dirichlet先验和Multinomial分布的共轭关系的证明待有空专门整理。
1.1.2 主题相关模型
文本的生成过程过程依然是生成单词序列的过程,只不过引入主题后,需要先确定主题,再根据主题生成单词的过程,这也更符合人类平常的语言习惯。
- PLSA模型:先投掷一个K面的doc-topic骰子,确定主题编号z,再从K个V面的topic-word骰子中选择编号为z的骰子投掷,得到一个单词,重复该过程生成文档。假设文档
dm={ω1,ω2,...,ωn} 都有其主题序列概率向量θ⃗ m ,编号为k的topic-word骰子的单词序列概率向量为φ⃗ k ,则文档dm 中每个单词w的生成概率为:p(ω|dm)=∑z=1Kp(ω|z)p(z|dm)=∑z=1Kφzwθmz(1.4)
则整篇文档dm 的生成概率为:p(ω⃗ |dm)=∏i=1n∑z=1Kp(ω|z)p(z|dm)=∏i=1n∑z=1Kφzwθmz(1.5) - LDA模型:相当于是为 PLSA模型中的主题序列向量
θ⃗ m 和单词序列向量为φ⃗ k 都加入了先验分布。从Unigram Model的投掷过程可知θ⃗ m和φ⃗ k 都是满足Multinomial分布的,因此先验分布的选择都是Dirichlet分布,分别为α⃗ 和β⃗ 。
1.1.3 LDA物理图模型
Figure 1. LDA物理图模型
根据该概率图可以得到生成文档的过程可表述为:
(1) 从Dirichlet分布
(2) 从主题的Multinomial分布
(3) 从Dirichlet分布
(4) 从词语的Multinomial分布
其中
同理
1.2 隐藏变量求解
PLSA模型和LDA模型模型都引入了隐藏变量,即主题
1.2.1 EM算法
EM算法的本质也是求解最大似然估计,只不过这里的似然函数是可观察样本
(1) E步:在参数确定的情况下(第一次迭代时会初始化参数)通过Jensen不等式【对凹函数而言,期望的函数值大于等于函数值的期望】找到
(2) M步:通过参数调整使得下界不断上升以此来逼近
详细的算法求解可参考《JerryLead—EM算法》。PLSA就是采用EM算法去求解“文档-主题”矩阵
1.2.2 Gibbs Sampling
LDA模型中隐藏变量的参数求解可以采用变分EM算法或Gibbs采样,Spark MLlib中采用的变分EM算法和在线学习的方法,其中EM LDA基于gibbs采样原理来估计参数的。这里只是简单概述下Gibbs采样的本质
Gibbs 采样解决的主要问题就是如何采样(采样主题及该主题下的单词)使得的采样的样本符合其联合概率的
在《LDA数学八卦》中给出了一个关于后代经济状态迭代的例子,即在给定初代收入阶级(下层,中层,上层)分布
对于给定的概率分布p(x),我们希望能有一种便捷的生成其样本。如果我们能构造一个转移矩阵为
其中:
GraphX基础知识
Spark GraphX的原理细节方面的知识可参考shijinkui/spark_graphx_analyze,这里只介绍GraphX本算法设计到的基本操作的概念
2.1 图的构造
从上面的分析可知,我们需要维护包含隐藏变量(主题)属性的
Figure 2. LDA模型生成图
其中
2.2 图的分布式存储
Grpahx中构建完图时默认会使用“partitionBy(PartitionStrategy.EdgePartition1D)”方法采用vertexcut(点分割)方式存储图(见Figure 3.)。这种存储方式特点 是任何一条边只会出现在一台机器上,每个点有可能分布到不同的机器上。当点被分割到不同机器上时,是相同的镜像,但是有一个点作为主点(master), 其他的点作为虚点(ghost),当点B的数据发生变化时,先更新点B的master的数据,然后将所有更新好的数据发送到B的ghost所在的所有机 器,更新B的ghost。这样做的好处是在边的存储上是没有冗余的,而且对于某个点与它的邻居的交互操作,只要满足交换律和结合律,比如求邻居权重的和, 求点的所有边的条数这样的操作,可以在不同的机器上并行进行,只要把每个机器上的结果进行汇总就可以了,网络开销也比较小。代价是每个点可能要存储多份, 更新点要有数据同步开销。
Figure 3. 点分割存储
2.3 图的聚合操作
LDA算法的迭代过程中需要对每篇文档的每个单词重新选取主题,即重新计算
Figure 4. 图的聚合操作
Spark MLlib实践
这里先给出官方使用LDA模型的例子JavaLatentDirichletAllocationExample.java中的部分关键代码,以此作为源码分析的切入点:
其中line33~46是加载原始文档—词频数据并将每行数据解析为Vector的结构。其中原始数据如下,每一行表示一个文档,每一列表示一个单词,每一个元素
line60中首先new LDA()中指定了默认使用变分EM算法(EMLDA)来学习模型,setK(3)指定了主题个数K为3(不指定的话默认为10),也对应上文中提到的最终构成的LDA图中文档顶点和图节点的数据内容是3维向量
org.apache.spark.mllib.clustering.LDA.scaladef this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
下面重点分析run方法,这里面完成了LDA模型图的构建以及参数的迭代。
org.apache.spark.mllib.clustering.LDA.scaladef run(documents: RDD[(Long, Vector)]): LDAModel = { val state = ldaOptimizer.initialize(documents, this) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { val start = System.nanoTime() state.next() val elapsedSeconds = (System.nanoTime() - start) / 1e9 iterationTimes(iter) = elapsedSeconds iter += 1 } state.getLDAModel(iterationTimes) }
(1) 其中initialize()方法完成图的构建(文档节点,词顶点以及文档节点指向词顶点的边),这里给出部分核心代码
org.apache.spark.mllib.clustering.EMLDAOptimizer.scalaoverride private[clustering] def initialize( docs: RDD[(Long, Vector)], lda: LDA): EMLDAOptimizer = { val docConcentration = lda.getDocConcentration //也即先验参数 alpha val topicConcentration = lda.getTopicConcentration //也即先验参数 beta val k = lda.getK //也即主题个数 ...... ...... this.docConcentration = if (docConcentration == -1) (50.0 / k) + 1.0 else docConcentration this.topicConcentration = if (topicConcentration == -1) 1.1 else topicConcentration val randomSeed = lda.getSeed // 为每篇文档中的每个词创建一条(文档->单词)的边,该边以三元组(文档Id,单词Id,词频)的形式存储 val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => // filter()方法过滤掉词频为0的单词 termCounts.asBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => /*这里的term2index为单词编号,在词频矩阵作为输入条件的情况下,每一列代表一个单词的词频信息,因此为列编号即为单词编号,方法实现里是按照(-1,-2,...)依次编号矩阵的列*/ Edge(docID, term2index(term), cnt) } } /* 根据边生成顶点,edge.srcId边的起点即文档顶点,edge.dstId边的中点即词顶点(顶点数据均为为K维数据),edge.attr边数据即词频信息。*/ val docTermVertices: RDD[(VertexId, TopicCounts)] = { val verticesTMP: RDD[(VertexId, TopicCounts)] = edges.mapPartitionsWithIndex { case (partIndex, partEdges) => val random = new Random(partIndex + randomSeed) partEdges.flatMap { edge => //先随机生成K维的主题分布参数 val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) //为K维随机参数*词频信息作为这条边为其两个顶点分配的K维顶点数据信息 val sum = gamma * edge.attr Seq((edge.srcId, sum), (edge.dstId, sum)) } } verticesTMP.reduceByKey(_ + _)//reduce操作汇总所有边分配给顶点的数据信息 } //根据顶点和边信息生成图,并采用点分割模式进行分布式存储 this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( checkpointInterval, graph.vertices.sparkContext) this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() //计算所有词的主题分布概率和 this }private def computeGlobalTopicTotals(): TopicCounts = { val numTopics = k /*filter方法从所有的图顶点中过滤出词顶点,在前面生成边过程中对单词编号(term2index)时,是按照(-1,-2,...)编号的,所以这里的过滤其实就是查看顶点编号是否小于0*/ graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) }
(2) 接着在循环体里面调用next()方法,迭代计算模型参数
org.apache.spark.mllib.clustering.EMLDAOptimizer.scalaoverride private[clustering] def next(): EMLDAOptimizer = { val eta = topicConcentration // 也即先验参数 beta val W = vocabSize // 单词总数,也即输入矩阵中列数 val alpha = docConcentration // 也即先验参数 alpha val N_k = globalTopicTotals //所有词的主题分布概率和 val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = (edgeContext) => { // N_{wj} 词汇w在文档中的频次 val N_wj = edgeContext.attr // E-STEP: 计算gamma_{wjk}:词汇w在文档j中分配给主题k的概率,参考公式(2.9) val scaledTopicDistribution: TopicCounts = computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj //将计算出来的gamma_{wjk}发送消息给边的源顶点和目的顶点 edgeContext.sendToDst((false, scaledTopicDistribution)) edgeContext.sendToSrc((false, scaledTopicDistribution)) } // 顶点合并消息,用于Map阶段,每个分区中每个点收到的消息合并,以及reduce阶段,合并不同分区的消息 val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = (m0, m1) => { val sum = if (m0._1) { m0._2 += m1._2 } else if (m1._1) { m1._2 += m0._2 } else { m0._2 + m1._2 } (true, sum) } // M-STEP: 每个节点通过收集邻居数据来更新主题权重数据 val docTopicDistributions: VertexRDD[TopicCounts] = graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) .mapValues(_._2) // 根据最新顶点数据更新图 val newGraph = Graph(docTopicDistributions, graph.edges) graph = newGraph graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this }
最后给出这个Example的输出,主题—词矩阵,如下:
- Spark MLlib — EMLDA
- Spark MLlib — Word2Vec
- Spark MLlib — Word2Vec
- Spark MLlib
- spark MLlib
- Spark MLLib
- Spark MLlib
- Spark-MLlib实例——逻辑回归
- Spark-MLlib实例——决策树
- Spark-MLlib实例——决策树
- Spark MLlib机器学习—封面
- [MLLib]一、Spark MLLib介绍
- Spark MLlib学习笔记之二——Spark Mllib矩阵向量
- spark mllib初探练习
- spark-mllib-TFIDF实现
- Spark MLlib SVM算法
- Spark MLlib FPGrowth算法
- Spark MLlib Statistics统计
- hue3.11主页面报错500
- HYSBZ 3531 旅行
- websocket getAsyncRemote()和getBasicRemote()区别
- Activity界面的加载和绘制
- [生存志] 第46节 秦穆公任贤霸西戎
- Spark MLlib — EMLDA
- 导致PHP程序死循环的一个原因
- div常用属性
- Android 即时音视频解决方案1——环信
- Java基础-IO流6 流的操作规律
- Kafka
- oracle数据库做定时任务(插入) 笔记
- iOS-常用工具库代码
- B. Spotlights