A JAVA Implementation of LDA Gibbs Sampling
Original article: http://blog.csdn.net/yangliuy/article/details/8457329
This series of posts introduces common probabilistic language models and their variants, mainly covering PLSA, LDA, the variants of LDA, and the parameter inference methods for these models. The planned contents are as follows:
Part 1: PLSA and the EM algorithm
Part 2: LDA and Gibbs Sampling
Part 3: Variants of LDA - Twitter LDA, TimeUserLDA, ATM, Labeled-LDA, MaxEnt-LDA, etc.
Part 4: A bibliography of paper classification work based on LDA variants
Part 5: A JAVA implementation of LDA Gibbs Sampling
Part 5  A JAVA Implementation of LDA Gibbs Sampling
In the first two posts of this series we systematically introduced PLSA, LDA, and their parameter inference methods, focusing on the model representations and formula derivations. A scholar once said that research should "reach for the sky while standing on the ground", meaning that models and theory alone are not enough; we also need solid code and experimental results on real data to back them up. This post therefore focuses on a JAVA implementation of LDA Gibbs Sampling, and presents the topic modeling results obtained by applying it to the newsgroup18828 news document collection.
The project is on GitHub: https://github.com/yangliuy/LDAGibbsSampling
1 Preprocessing the document collection
To do topic modeling on text with LDA, we first have to preprocess the text: tokenization, stop-word removal, stemming, removal of noise words, removal of low-frequency words, and so on. When the corpus is fairly large, stemming can also be skipped. The text is then converted into a term-index representation, because the LDA implementation later needs to map between terms and indices frequently. The Documents class is implemented as follows; it defines an inner Document class that describes a single document in the collection.
package liuyang.nlp.lda.main;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.com.Stopwords;

/**Class for corpus which consists of M documents
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class Documents {

    ArrayList<Document> docs;
    Map<String, Integer> termToIndexMap;
    ArrayList<String> indexToTermMap;
    Map<String, Integer> termCountMap;

    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    public void readDocs(String docsPath) {
        for (File docFile : new File(docsPath).listFiles()) {
            Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);
            docs.add(doc);
        }
    }

    public static class Document {
        private String docName;
        int[] docWords;

        public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap) {
            this.docName = docName;
            //Read the file and collect its tokens
            ArrayList<String> docLines = new ArrayList<String>();
            ArrayList<String> words = new ArrayList<String>();
            FileUtil.readLines(docName, docLines);
            for (String line : docLines) {
                FileUtil.tokenizeAndLowerCase(line, words);
            }
            //Remove stop words and noise words
            for (int i = 0; i < words.size(); i++) {
                if (Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))) {
                    words.remove(i);
                    i--;
                }
            }
            //Map each word to its term index, growing the vocabulary as needed
            this.docWords = new int[words.size()];
            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                if (!termToIndexMap.containsKey(word)) {
                    int newIndex = termToIndexMap.size();
                    termToIndexMap.put(word, newIndex);
                    indexToTermMap.add(word);
                    termCountMap.put(word, 1);
                    docWords[i] = newIndex;
                } else {
                    docWords[i] = termToIndexMap.get(word);
                    termCountMap.put(word, termCountMap.get(word) + 1);
                }
            }
            words.clear();
        }

        public boolean isNoiseWord(String string) {
            string = string.toLowerCase().trim();
            //Filter URLs such as @xxx, www.xxx, xxx.com and http links
            if (string.matches(".*www\\..*") || string.matches(".*\\.com.*") ||
                    string.matches(".*http:.*")) {
                return true;
            }
            //Filter tokens that contain no letters at all
            Pattern pattern = Pattern.compile(".*[a-zA-Z]+.*");
            Matcher m = pattern.matcher(string);
            return !m.matches();
        }
    }
}
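As a quick sanity check of the preprocessing, the corpus class can be exercised on its own. This is a hypothetical snippet (DocumentsDemo is not part of the project; the path is the example data directory used in section 3):

package liuyang.nlp.lda.main;

public class DocumentsDemo {
    public static void main(String[] args) {
        Documents docSet = new Documents();
        docSet.readDocs("data/LdaOriginalDocs");
        System.out.println("number of documents: " + docSet.docs.size());
        System.out.println("vocabulary size: " + docSet.termToIndexMap.size());
    }
}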
2 LDA Gibbs Sampling
Once the text preprocessing is done, we can implement LDA Gibbs Sampling. First we need to define the required parameters. My implementation provides default values in the program and also supports overriding them from a configuration file; by default the program prefers the settings from the configuration file. The overall algorithm flow consists of initializing the model, running the iterative inference that keeps updating the topic assignments and the parameters to be estimated, and finally outputting the parameter estimates at convergence.
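For reference, the Gibbs sampling update implemented below is the standard collapsed full conditional for LDA. Writing $t$ for the term at the sampled position, $n_{mk}$ for the number of words in document $m$ assigned to topic $k$, and $n_{kt}$ for the number of times term $t$ is assigned to topic $k$ (all counts excluding the current position), the sampling distribution is

$$p(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w}) \;\propto\; \frac{n_{kt} + \beta}{\sum_{t'} n_{kt'} + V\beta} \cdot \frac{n_{mk} + \alpha}{\sum_{k'} n_{mk'} + K\alpha},$$

and the point estimates of the two parameter matrices are

$$\hat{\varphi}_{kt} = \frac{n_{kt} + \beta}{\sum_{t'} n_{kt'} + V\beta}, \qquad \hat{\theta}_{mk} = \frac{n_{mk} + \alpha}{\sum_{k'} n_{mk'} + K\alpha}.$$

These are exactly the expressions computed in sampleTopicZ and updateEstimatedParameters in the LdaModel class below.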
The class that contains the main function and parses the configuration parameters is as follows:
package liuyang.nlp.lda.main;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.ConstantConfig;
import liuyang.nlp.lda.conf.PathConfig;

/**Liu Yang's implementation of Gibbs Sampling of LDA
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LdaGibbsSampling {

    public static class modelparameters {
        float alpha = 0.5f; //usual value is 50 / K
        float beta = 0.1f; //usual value is 0.1
        int topicNum = 100;
        int iteration = 100;
        int saveStep = 10;
        int beginSaveIters = 50;
    }

    /**Get parameters from the configuration file. If the
     * configuration file contains a value for a parameter, use it;
     * otherwise the default value set in the program is used.
     * @param ldaparameters
     * @param parameterFile
     */
    private static void getParametersFromFile(modelparameters ldaparameters,
            String parameterFile) {
        ArrayList<String> paramLines = new ArrayList<String>();
        FileUtil.readLines(parameterFile, paramLines);
        for (String line : paramLines) {
            if (line.trim().isEmpty()) continue; //skip blank lines in the configuration file
            String[] lineParts = line.split("\t");
            switch (parameters.valueOf(lineParts[0])) {
            case alpha:
                ldaparameters.alpha = Float.valueOf(lineParts[1]);
                break;
            case beta:
                ldaparameters.beta = Float.valueOf(lineParts[1]);
                break;
            case topicNum:
                ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
                break;
            case iteration:
                ldaparameters.iteration = Integer.valueOf(lineParts[1]);
                break;
            case saveStep:
                ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
                break;
            case beginSaveIters:
                ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
                break;
            }
        }
    }

    public enum parameters {
        alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
    }

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        String originalDocsPath = PathConfig.ldaDocsPath;
        String resultPath = PathConfig.LdaResultsPath;
        String parameterFile = ConstantConfig.LDAPARAMETERFILE;

        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        Documents docSet = new Documents();
        docSet.readDocs(originalDocsPath);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(new File(resultPath));
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        model.saveIteratedModel(ldaparameters.iteration, docSet);
        System.out.println("Done!");
    }
}
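Since getParametersFromFile splits each line on a tab character and matches the first field against the parameters enum, the configuration file (whose path is taken from ConstantConfig.LDAPARAMETERFILE) is expected to contain one tab-separated key-value pair per line. A hypothetical example using the built-in defaults:

alpha	0.5
beta	0.1
topicNum	100
iteration	100
saveStep	10
beginSaveIters	50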
The LDA model class is implemented as follows:
package liuyang.nlp.lda.main;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.PathConfig;

/**Class for Lda model
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LdaModel {

    int[][] doc; //word index array
    int V, K, M; //vocabulary size, topic number, document number
    int[][] z; //topic label array
    float alpha; //doc-topic dirichlet prior parameter
    float beta; //topic-word dirichlet prior parameter
    int[][] nmk; //given document m, count times of topic k. M*K
    int[][] nkt; //given topic k, count times of term t. K*V
    int[] nmkSum; //Sum for each row in nmk
    int[] nktSum; //Sum for each row in nkt
    double[][] phi; //Parameters for topic-word distribution K*V
    double[][] theta; //Parameters for doc-topic distribution M*K
    int iterations; //Number of iterations
    int saveStep; //Number of iterations between two savings
    int beginSaveIters; //Begin saving the model at this iteration

    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    public void initializeModel(Documents docSet) {
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        //initialize the document index array
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            //Notice the limit of memory
            int N = docSet.docs.get(m).docWords.length;
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.docs.get(m).docWords[n];
            }
        }

        //initialize the topic label z for each word
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.docs.get(m).docWords.length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int initTopic = (int) (Math.random() * K); //from 0 to K - 1
                z[m][n] = initTopic;
                //number of words in doc m assigned to topic initTopic adds 1
                nmk[m][initTopic]++;
                //number of terms doc[m][n] assigned to topic initTopic adds 1
                nkt[initTopic][doc[m][n]]++;
                //total number of words assigned to topic initTopic adds 1
                nktSum[initTopic]++;
            }
            //total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel(Documents docSet) throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));
            System.exit(1);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                //Save the model
                System.out.println("Saving model at iteration " + i + " ... ");
                //First update the parameter estimates
                updateEstimatedParameters();
                //Then print the model variables
                saveIteratedModel(i, docSet);
            }

            //Use Gibbs sampling to update z[][]
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                for (int n = 0; n < N; n++) {
                    //Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
        //Update the estimates once more so that the final save after the loop
        //(saveIteratedModel called from main) reflects the last sampling state
        updateEstimatedParameters();
    }

    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    private int sampleTopicZ(int m, int n) {
        //Sample from p(z_i|z_-i, w) using the Gibbs update rule
        //Remove the topic label of w_{m,n}
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        //Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        //Sample a new topic label for w_{m,n} by roulette selection
        //Compute the cumulative probabilities of p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; //p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        //Add the new topic label to w_{m,n}
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    public void saveIteratedModel(int iters, Documents docSet) throws IOException {
        //lda.params lda.phi lda.theta lda.tassign lda.twords
        //lda.params
        String resPath = PathConfig.LdaResultsPath;
        String modelName = "lda_" + iters;
        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(resPath + modelName + ".params", lines);

        //lda.phi K*V
        BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < V; j++) {
                writer.write(phi[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.theta M*K
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < K; j++) {
                writer.write(theta[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.tassign
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < doc[m].length; n++) {
                writer.write(doc[m][n] + ":" + z[m][n] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.twords phi[][] K*V
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));
        int topNum = Math.min(20, V); //output the top 20 topic words of each topic (or fewer if V < 20)
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(j);
            }
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
            writer.write("topic " + i + "\t:\t");
            for (int t = 0; t < topNum; t++) {
                writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public static class TwordsComparable implements Comparator<Integer> {

        public double[] sortProb; //probability of each word in topic k

        public TwordsComparable(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            //Sort the word indices in descending order of probability in topic k
            if (sortProb[o1] > sortProb[o2]) return -1;
            else if (sortProb[o1] < sortProb[o2]) return 1;
            else return 0;
        }
    }
}
The implementation details can be followed through the comments I give in the code; if you understand the algorithmic flow of LDA Gibbs sampling, the code above is straightforward. In fact, leaving aside the I/O and parameter-parsing code, the Gibbs sampler for standard LDA takes less than 200 lines of code. Of course, there are many places where it could be optimized or generalized; one such micro-optimization is sketched below.
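As one example of such an optimization, here is a sketch (not part of the original project) of a variant of sampleTopicZ that hoists the loop-invariant terms V * beta, K * alpha, and the document-length denominator out of the loop over topics, builds the cumulative weights in a single pass, and reuses a caller-allocated buffer instead of allocating a new array for every token:

//Hypothetical micro-optimized variant of sampleTopicZ; assumes the same
//fields as LdaModel above. p is a reusable buffer of length K.
private int sampleTopicZFast(int m, int n, double[] p) {
    int t = doc[m][n];
    //Remove the topic label of w_{m,n}
    int oldTopic = z[m][n];
    nmk[m][oldTopic]--;
    nkt[oldTopic][t]--;
    nmkSum[m]--;
    nktSum[oldTopic]--;

    //Loop invariants: identical for every candidate topic k
    double Vbeta = V * beta;
    double docNorm = nmkSum[m] + K * alpha; //constant factor; could even be dropped
    //Build the cumulative (unnormalised) weights in one pass
    double cum = 0.0;
    for (int k = 0; k < K; k++) {
        cum += (nkt[k][t] + beta) / (nktSum[k] + Vbeta) * (nmk[m][k] + alpha) / docNorm;
        p[k] = cum;
    }

    //Roulette selection on the cumulative weights
    double u = Math.random() * p[K - 1];
    int newTopic = 0;
    while (newTopic < K - 1 && u >= p[newTopic]) {
        newTopic++;
    }

    //Add the new topic label to w_{m,n}
    nmk[m][newTopic]++;
    nkt[newTopic][t]++;
    nmkSum[m]++;
    nktSum[newTopic]++;
    return newTopic;
}

The sampled distribution is unchanged; only redundant arithmetic and the per-token allocation are removed.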
The source files under the com and conf directories hold the common utility functions and the configuration classes, respectively. The complete JAVA project is on GitHub: https://github.com/yangliuy/LDAGibbsSampling
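For completeness, here is a minimal sketch of the FileUtil utility class that the listings above depend on, reconstructed from its call sites (readLines, tokenizeAndLowerCase, writeLines, mkdir); the real class in the GitHub project may differ in details such as the tokenization regex and error handling:

package liuyang.nlp.lda.com;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;

public class FileUtil {

    //Read all lines of a file into the given list
    public static void readLines(String file, ArrayList<String> lines) {
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Split a line on runs of non-alphanumeric characters and append the lower-cased tokens
    public static void tokenizeAndLowerCase(String line, ArrayList<String> tokens) {
        for (String token : line.split("[^a-zA-Z0-9]+")) {
            if (!token.isEmpty()) {
                tokens.add(token.toLowerCase());
            }
        }
    }

    //Write the given lines to a file, one per line
    public static void writeLines(String file, ArrayList<String> lines) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
            for (String line : lines) {
                writer.write(line);
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Create the directory if it does not already exist
    public static void mkdir(File dir) {
        if (!dir.exists()) {
            dir.mkdirs();
        }
    }
}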
3 Topic analysis of the Newsgroup 18828 document collection with LDA Gibbs Sampling
Below we present the results of applying the LDA Gibbs sampling implementation above to topic analysis of the Newsgroup 18828 document collection. The data I used in the experiments has been uploaded to GitHub, so interested readers can download the project from GitHub and run it directly. From the Newsgroup 18828 collection I randomly selected 9 directories, chose one document from each, and placed them under the data\LdaOriginalDocs directory. The model parameters I set are as follows:
    alpha	0.5
    beta	0.1
    topicNum	10
    iteration	100
    saveStep	10
    beginSaveIters	80
That is, alpha is set to 0.5 and beta to 0.1, the number of topics is 10, and the sampler runs for 100 iterations, saving the model every 10 iterations starting from iteration 80; with these settings, intermediate models are saved at iterations 80 and 90, and the final model is saved after iteration 100.
After 100 iterations of Gibbs sampling, the program outputs the top topic words and the corresponding probability values under each of the 10 topics.
We can see that, even though this is unsupervised learning, the topic words found by LDA make a great deal of sense: for instance, topic 5 is about religion, topic 6 about astronomy, and topic 7 about computers. The program output also includes the model parameters in a .params file, the topic-word distribution vectors phi in a .phi file, the doc-topic distribution vectors theta in a .theta file, and the topic label assigned to each word of each document in a .tassign file. Interested readers can download the complete project from GitHub at https://github.com/yangliuy/LDAGibbsSampling and run topic analysis experiments on other datasets. This program is a preliminary implementation; if you find any problems or bugs, I would be glad to hear about them and will fix them on GitHub as soon as possible.
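As a concrete illustration of the .tassign format mentioned above: each line corresponds to one document and consists of tab-separated wordIndex:topicLabel pairs, exactly as written by saveIteratedModel. With hypothetical indices and labels, the line for a five-word document might look like

    0:3	15:3	7:9	23:0	15:3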