google PLDA + 实现原理及源码分析

来源:互联网 发布:informix数据库论坛 编辑:程序博客网 时间:2024/05/11 22:39

LDA背景

LDA(隐含狄利克雷分布)是一个主题聚类模型,是当前主题聚类领域最火、最有力的模型之一,它能通过多轮迭代把特征向量集合按主题分类。目前,广泛运用在文本主题聚类中。
LDA的开源实现有很多。目前广泛使用、能够分布式并行处理大规模语料库的有微软的LightLDA,谷歌plda、plda+,sparkLDA等等。下面介绍这3种LDA:

LightLDA依赖于微软自己实现的multiverso参数服务器,服务器底层使用mpi或zeromq发送消息。LDA模型(word-topic矩阵)由参数服务器保存,它为文档训练进程提供参数查询、更新服务。

plda、plda+使用mpi消息通信,将mpi进程分为word、doc俩部分。doc进程训练文档,word进程为doc进程提供模型的查询、更新功能。

spark LDA有两种实现:1.基于gibbs sampling原理和使用GraphX实现的版本(即spark文档上所说的EMLDAOptimizer and DistributedLDAModel),2.基于变分推断原理的实现的版本(即spark文档上的OnlineLDAOptimizer and LocalLDAModel)。Spark LDA的介绍请参见这里。

LightLDA,plda、plda+,spark LDA比较

论能够处理预料库的规模大小,LihgtLDA要远远好于plda和spark LDA
经过测试,在10个服务器(8核40GB)集群规模下:
LihgtLDA能够处理上亿文档、百万词汇的语料库,能够训练上百万主题数。这样的处理能力使得LihgtLDA能够轻松训练绝大多数语料库。微软号称使用几十机器的集群便能训练Bing搜索引擎爬下数据的十分之一。
相对于LihgtLDA ,plda+能够处理规模小的多,上限是:词汇数目*主题数(模型大小) < 5亿。当语料库规模达到上限后,mpi集群会因内存不够而终止,或因为内存数据频繁切换,迭代速度十分缓慢。虽然plda+对语料库的词汇数目和训练的主题数目很敏感,但对文档的规模并不是很敏感,在词汇数目和主题数目较小的情况下,1000万级别的文档也能够轻松解决。
spark LDA的GraphX版处理规模衡量标准是图的顶点数据,即(文档数 + 词汇数目)*主题数目,上限是 文档数*主题数 < 50亿(由于词汇数目相对于文档数目往往较小,近似等于 文档数*主题数)。当超过这个规模后,spark集群进入假死状态。不停有节点出现OOM,直至任务以失败告终。
变分推断实现的spark LDA瓶颈是 词汇数目*主题数目,这个值也就是我们所说的模型大小,上限约1亿。为什么存在这个瓶颈呢?是因为变分推断的实现过程中,模型使用矩阵本地存储,各个分区计算模型的部分值,然后在driver上将矩阵reduce叠加。当模型过大,driver节点的内存就无法承受各个分区发过来的模型。
收敛速度上,LightLDA要远快于plda、plda+和spark LDA。小规模语料库(30万文档,10万词,1000主题)测试,LightLDA : plda+ : spark LDA(graphx) = 1:4:50
为什么各种LDA的能够处理语料库规模的衡量标准不一样呢?这与它们的实现方式有关,不同的LDA有不同的瓶颈,我们这里单讲plda+的源码解读,其他lda后续介绍

plda+介绍

plda+是LDA的并行C++实现,由谷歌公司开发,它分布式基础是MPI,使用高度优化的Gibbs sampling算法训练文档。
这里写图片描述
如图所示,plda+将mpi进程组分为2部分——word进程和doc进程。

word进程

word进程存储plda+的模型,使用分布是存储方式,每个进程只负责模型的一部分。LDA的模型指的是word-topic矩阵(矩阵大小=词汇数目x主题数目),矩阵每行表示语料库中一word在各个topic中出现的次数。实现上,每行word-topic由向量或数组表示。word进程负责为doc进程提供word-topic模型参数(即矩阵中的一行,word的各topic出现次数),响应doc进程发送过来的模型更新消息。它的角色就相当于一个参数服务器。

doc进程

doc进程是plda+存储文档的地方,也是训练文档的地方。也采用分布式存储方式,每个进程只持有语料库的一部分文档。另外,doc进程还分布是存储doc-topic矩阵,doc-topic矩阵(矩阵大小=文档数目x主题数目)描述语料库各文档doc中的所有词在各个topic下的数目。doc进程从word进程获取word-topic参数和global_topic参数(每个主题拥有的词的数目,由word-topic矩阵按行叠加),依据gibbs sampling算法为每个词的重新选取主题,将词的主题选取情况发送消息给word进程,通知其更新模型。

doc进程主要由3部分组成:

  • 文档集合
    文档由分布到该doc进程的所有文档组成,各文档记录了自己的词频信息和自己的各个词主题选取信息。
  • local word
    local word表示doc进程中所拥有的文档的词汇集合。进程建立了词到文档的反转索引word_inverted_index数据结构,能够使用word来遍历所有拥有该词的文档。
  • route(路由表)
    route为每个local word记录了一个mpi进程号,这个进程号即word进程的编号,表示该local word对应的word-topic模型由这个word进程负责。有了route,doc进程发送word-topic请求和更新消息时便知道往哪个word进程发送了.

MPI消息

plda+中总是doc进程向word进程主动发送请消息,word进程响应doc进程的请求消息,不存在其他的消息通信方式,如doc进程和doc进程之间、word进程和word进程之间,就不存在消息通信。

消息通信的类型有以下几种:

  • PLDAPLUS_TAG_FETCH
    doc进程向word进程发起word-topic参数请求。消息通信的方式是:请求-应答机制,word进程收到请求后向doc进程发送数据。所有的消息通信采用异步发送方式,doc进程与word进程发送消息后无需等待,继续做其他的事。
  • PLDAPLUS_TAG_FETCH_GLOBAL
    doc进程向word进程请求global_topic参数
  • PLDAPLUS_TAG_UPDATE
    doc进程向word进程发送模型更新消息,消息中包含doc进程local word中某个词的主题变化情况。
  • PLDAPLUS_TAG_DONE
    doc进程通知word进程训练结束,word进程退出消息等待的主循环。这类消息在最后一轮迭代完毕后,doc进程才发送。

plda+初始化

plda+要求在集群各个节点放置一份完整的语料库文档,各个进程从完整语料库中抽取文档和词来初始化一些重要的数据结构。由前面介绍可知,word进程和doc进程所需的数据并不相同,因而他们的初始化行为也不一样。好在mpi能够根据进程号来判断,有区别的让不同进程执行不同的代码。下面是主要的初始化步骤:

所有进程

  1. 建立word_index_map
    word_index_map是语料库词汇到索引的映射结构(c++ map),实现上是先将词汇按字符串顺序排序,把词汇映射到序号。之后,使用索引来代表词汇,由于处理int比string效率要高的多,这个做法可以提升效率。
  2. 建立word_pw_map (路由表)
    word_pw_map是local word到word进程的映射,就是我们上述的route结构。

doc进程

  1. 将语料库中的文档进行按doc进程轮流分配,分配完毕后,便确定了各doc进程拥有的文档集合local word
  2. 为文档中的词随机选择主题,形成doc-topic
  3. 将local word的主题初始主题情况发送给word进程,通知其更新模型。

Word进程

  1. 按照route路由表为word进程分配word,分配完毕后各进程便拥有各自的local word,进行编号,形成本地索引。建立global_local_word_index_map_,实现语料库中词全局索引到进程中本地索引之间的映射
  2. 为本地的词建立空的word-topic模型,其初始值为0
  3. 进行listen,listen是word进程接下来一直进行的事,它在不停地循环等待doc进程的消息,直到接收到所有doc进程的PLDAPLUS_TAG_DONE消息后才退出
  4. listen在初始化阶段,word进程主要接收的是doc进程发送过来的模型更新消息,形成初始模型。在后续迭代阶段,便响应doc进程的各种消息。

word进程listen实现

就像大多数服务器程序逻辑一样,listen不断执行循环,等待消息,响应消息…

do {    MPI_Recv(recv_buf, num_topics_t, MPI_LONG_LONG,             MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);    int tag = status.MPI_TAG;    int source = status.MPI_SOURCE;    switch(tag & 3) {   // get the last two bits      case PLDAPLUS_TAG_FETCH : {        MPI_Wait(&req, &status);        map<int, int>::iterator iter =            global_local_word_index_map_.find(tag >> PLDAPLUS_TAG_LENGTH);        if(iter != global_local_word_index_map_.end()) {          const TopicCountDistribution&   topic_word =  //请求词的topic参数              GetWordTopicDistribution(iter->second); //将请求的词转为本地索引          topic_word.replicate(send_buf);        }        MPI_Isend(send_buf, num_topics_t, MPI_LONG_LONG,                  source, tag, MPI_COMM_WORLD, &req); //异步发送消息        break;      }      case PLDAPLUS_TAG_FETCH_GLOBAL : {                ComputeLocalWordLlh(); //tlz        if(first_flag) {          first_flag = false;        } else {          MPI_Wait(&req, &status);        }        const TopicCountDistribution& global_topic =            GetGlobalTopicDistribution();        global_topic.replicate(send_buf);        MPI_Isend(send_buf, num_topics_t, MPI_LONG_LONG,                  source, tag, MPI_COMM_WORLD, &req);        break;      }      case PLDAPLUS_TAG_UPDATE : {        int word_index = global_local_word_index_map_[tag >> PLDAPLUS_TAG_LENGTH];        for(int k = 0; k < num_topics_t; ++k) {          IncrementTopic(word_index, k, recv_buf[k]);  //更新模型        }        break;      }      case PLDAPLUS_TAG_DONE : {        ++count_done;  //累加doc进程发来的PLDAPLUS_TAG_DONE消息        break;      }      default : {        // tag error      }    }  } while(count_done < pdnum);  //收到所有doc进程的PLDAPLUS_TAG_DONE退出listen

word进程就像参数服务器,不停地为doc进程提供word-topic和global-topic模型参数

doc进程的词优先顺序训练文档

plda+仍然使用传统的gibbs sampling算法,但它在训练顺序上进行了大胆的创新。
这里写图片描述
原始的文档训练采用文档优先顺序,即为语料库中每篇文档里的每一个词使用gibbs sampling确定新的主题。

plda+为doc进程中的文档建立了词-文档反转索引(word_inverted_index),local word中的每个词能够索引一系列含有该词的文档。plda+将训练过程该为local word中的每一个词对应的每篇文档进行gibbs sampling,为该文档中该词选取新的主题。
下面是词(本地索引是local_word_index)训练代码:

    // Sample for word local_word_index    for(list<InvertedIndex*>::iterator iter = pldaplus_corpus->word_inverted_index[local_word_index].begin();        iter != pldaplus_corpus->word_inverted_index[local_word_index].end(); ++iter) {      SampleNewTopicForWordInDocumentWithDistributions(          (*iter)->word_index_in_document,          (*iter)->document_ptr, train_model,          topic_word, global_topic, delta_topic);    }

迭代器iter是InvertedIndex结构的指针,它会遍历指向local_word_index对应的所有InvertedIndex结构,InvertedIndex结构会记录local_word_index对应的文档(document_ptr),以及该词在文档中的编号(word_index_in_document)。
gibbs sampling过程会为该文档该词选择新的主题,它其实并不复杂,只是依据采用公式通过word-topic(代码中是topic_word),global_topic,doc-topic参数,该词新的主题,并把主题更新信息记录在delta_topic数组中。

plda+按词优先训练文档中的词有以下几个优势:
1. 减少了local word主题更新信息的存储,当local_word_index对应的所有文档处理完毕,doc进程为该词发送模型更新消息。训练该词对应文档的过程只需要一个delta_topic数组的空间存储即可。若按文档优先顺序来sampling,要将取样结果更新到模型必然要经历下列方式之一:每取样一个词,就发送模型更新消息,这会导致大量的通信。一篇文档训练完毕或所有文档训练完毕才发送更新模型消息,这需要记录所有词的主题更新信息,因而会带来大量存储开销。
2. 模型更新速度适宜,文档优先顺序中所有文档处理完毕再发送消息,虽然发送的更新消息量非常小,但对模型更新来说,这是一个大同步,会导致模型的收敛速度便慢,甚至会出现抖动。
3. 词优先训练文档每处理一个word发送更新消息,使得消息发送(异步发送)与文档训练交叉执行,使得通信与计算重叠,提高了系统的吞吐率。

global_topic参数获取

void PLDAPLUSModelForPd::GetGlobalTopic(int64* global_topic) {  int num_topics_t = num_topics();  MPI_Status  status;  MPI_Send(buf_, 0, MPI_LONG_LONG, 0, PLDAPLUS_TAG_FETCH_GLOBAL, MPI_COMM_WORLD);  MPI_Recv(global_topic, num_topics_t, MPI_LONG_LONG,           0, PLDAPLUS_TAG_FETCH_GLOBAL, MPI_COMM_WORLD, &status);  for(int dest = 1; dest < pwnum_; ++dest) {    MPI_Send(buf_, 0, MPI_LONG_LONG,             dest, PLDAPLUS_TAG_FETCH_GLOBAL, MPI_COMM_WORLD);    MPI_Recv(buf_, num_topics_t, MPI_LONG_LONG,             dest, PLDAPLUS_TAG_FETCH_GLOBAL, MPI_COMM_WORLD, &status);    for(int k = 0; k < num_topics_t; ++k) {      global_topic[k] += buf_[k];    }  }}

doc进程依次向各个word进程发送PLDAPLUS_TAG_FETCH_GLOBAL消息,将word进程响应的局部global_topic累加,形成global_topic。为什么说是局部呢?因为global_topic是每个主题拥所有词的总数目,每个word进程只能统计它自己拥有的那一部分模型。
由于进程数目本来较小,plda+为了实现简单,doc进程使用同步方式进行通信。

word-topic参数获取和消息的异步机制

  // Init fetching pool  for(int i = 0; i < num_words_t && pool_size < PLDAPLUS_MAX_POOL_SIZE; ++i) {    model_pd_->GetTopicWordNonblocking(i, recv_buf + pool_size * num_topics_t,                                       request_pool + pool_size);    word_index_pool[pool_size] = i;    ++pool_size;  }  for(int i = pool_size; i < num_words_t; ++i) {    // Wait for fetching any topic word distribution    MPI_Waitany(PLDAPLUS_MAX_POOL_SIZE, request_pool, &request_index, &status);    // Redirect topic word pointer    topic_word = recv_buf + request_index * num_topics_t;    memset(delta_topic, 0, sizeof(*delta_topic) * num_topics_t);    int local_word_index = word_index_pool[request_index];        model_pd_->UpdateWordCoverTopic(local_word_index, topic_word);    // Sample for word local_word_index    for(list<InvertedIndex*>::iterator iter = pldaplus_corpus->word_inverted_index[local_word_index].begin();        iter != pldaplus_corpus->word_inverted_index[local_word_index].end(); ++iter) {      SampleNewTopicForWordInDocumentWithDistributions(          (*iter)->word_index_in_document,          (*iter)->document_ptr, train_model,          topic_word, global_topic, delta_topic);    }    // Update for word local_word_index    model_pd_->UpdateTopicWord(local_word_index, delta_topic);    for(int k = 0; k < num_topics_t; ++k) {      global_topic[k] += delta_topic[k];    }    // Fetch next topic word distribution    model_pd_->GetTopicWordNonblocking(i, topic_word, request_pool + request_index);    word_index_pool[request_index] = i;  }

doc进程使用异步的MPI消息请求word_topic参数,异步方式即请求后不原地等待word进程的响应。plda+的实现如下:
1. plda+在为local word请求word-topic参数时,最开始发出一池子(100个)的word_topic请求,将他们放到消息池中,监控池子中响应的到来。
2. 每到来一个响应,便根据响应消息带来的word_topic参数训练该词对应的一系列文档。训练完毕后,发送该词的模型更新消息。该词处理完毕,在消息池中占的位置也可被占用。
3. 发送下一个local word的word-topic参数请求,用前一个词在消息池中的位置来存放请求的消息,进行监控。直到local word全部训练完毕。

plda+ loglikelihood计算问题

使用过plda+的同学可能发现,plda+的loglikelihood的值竟然是随迭代次数增加而递减的,这严重不符合likelihood的定义,随着迭代加深,似然函数的值应该不断逼近最大值。
笔者依照LihgtLda的计算方式重新实现了plda+的loglikelihood计算。参见:https://github.com/tanglizhe1105/plda
我们把loglikelihood的计算拆成2部分:doc-topic矩阵和word-topic模型矩阵,其中word-topic为了方便计算又分为了word loglikelihood和normalized loglikelihood。拆分的理由请见:https://github.com/Microsoft/lightlda/issues/9
因而,word-topic的loglikelihood为 word +normalized loglikelihood。
总的loglikelihood = doc + word + normalized loglikelihood。

作者介绍

唐黎哲,国防科学技术大学 并行与分布式计算国家重点实验室(PDL)硕士,从事spark、图计算、LDA(主题分类)研究,欢迎交流,请多指教。
邮箱:tanglizhe1105@qq.com

1 0