LDA Gibbs Sampling 的JAVA实现

来源:互联网 发布:淘宝卖男装的那些店好 编辑:程序博客网 时间:2024/05/20 10:23

原文地址:http://blog.csdn.net/yangliuy/article/details/8457329

本系列博文介绍常见概率语言模型及其变形模型,主要总结PLSA、LDA及LDA的变形模型及参数Inference方法。初步计划内容如下

第一篇:PLSA及EM算法

第二篇:LDA及Gibbs Samping

第三篇:LDA变形模型-Twitter LDA,TimeUserLDA,ATM,Labeled-LDA,MaxEnt-LDA等

第四篇:基于变形LDA的paper分类总结(bibliography)

第五篇:LDA Gibbs Sampling 的JAVA实现


第五篇 LDA Gibbs Sampling的JAVA 实现

在本系列博文的前两篇,我们系统介绍了PLSA, LDA以及它们的参数Inference 方法,重点分析了模型表示和公式推导部分。曾有位学者说,“做研究要顶天立地”,意思是说做研究空有模型和理论还不够,我们还得有扎实的程序code和真实数据的实验结果来作为支撑。本文就重点分析 LDA Gibbs Sampling的JAVA 实现,并给出apply到newsgroup18828新闻文档集上得出的Topic建模结果。

本项目Github地址 https://github.com/yangliuy/LDAGibbsSampling


1、文档集预处理

要用LDA对文本进行topic建模,首先要对文本进行预处理,包括token,去停用词,stem,去noise词,去掉低频词等等。当语料库比较大时,我们也可以不进行stem。然后将文本转换成term的index表示形式,因为后面实现LDA的过程中经常需要在term和index之间进行映射。Documents类的实现如下,里面定义了Document内部类,用于描述文本集合中的文档。

[java] view plaincopy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. import java.io.File;  
  4. import java.util.ArrayList;  
  5. import java.util.HashMap;  
  6. import java.util.Map;  
  7. import java.util.regex.Matcher;  
  8. import java.util.regex.Pattern;  
  9.   
  10. import liuyang.nlp.lda.com.FileUtil;  
  11. import liuyang.nlp.lda.com.Stopwords;  
  12.   
  13. /**Class for corpus which consists of M documents 
  14.  * @author yangliu 
  15.  * @blog http://blog.csdn.net/yangliuy 
  16.  * @mail yangliuyx@gmail.com 
  17.  */  
  18.   
  19. public class Documents {  
  20.       
  21.     ArrayList<Document> docs;   
  22.     Map<String, Integer> termToIndexMap;  
  23.     ArrayList<String> indexToTermMap;  
  24.     Map<String,Integer> termCountMap;  
  25.       
  26.     public Documents(){  
  27.         docs = new ArrayList<Document>();  
  28.         termToIndexMap = new HashMap<String, Integer>();  
  29.         indexToTermMap = new ArrayList<String>();  
  30.         termCountMap = new HashMap<String, Integer>();  
  31.     }  
  32.       
  33.     public void readDocs(String docsPath){  
  34.         for(File docFile : new File(docsPath).listFiles()){  
  35.             Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);  
  36.             docs.add(doc);  
  37.         }  
  38.     }  
  39.       
  40.     public static class Document {    
  41.         private String docName;  
  42.         int[] docWords;  
  43.           
  44.         public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap){  
  45.             this.docName = docName;  
  46.             //Read file and initialize word index array  
  47.             ArrayList<String> docLines = new ArrayList<String>();  
  48.             ArrayList<String> words = new ArrayList<String>();  
  49.             FileUtil.readLines(docName, docLines);  
  50.             for(String line : docLines){  
  51.                 FileUtil.tokenizeAndLowerCase(line, words);  
  52.             }  
  53.             //Remove stop words and noise words  
  54.             for(int i = 0; i < words.size(); i++){  
  55.                 if(Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))){  
  56.                     words.remove(i);  
  57.                     i--;  
  58.                 }  
  59.             }  
  60.             //Transfer word to index  
  61.             this.docWords = new int[words.size()];  
  62.             for(int i = 0; i < words.size(); i++){  
  63.                 String word = words.get(i);  
  64.                 if(!termToIndexMap.containsKey(word)){  
  65.                     int newIndex = termToIndexMap.size();  
  66.                     termToIndexMap.put(word, newIndex);  
  67.                     indexToTermMap.add(word);  
  68.                     termCountMap.put(word, new Integer(1));  
  69.                     docWords[i] = newIndex;  
  70.                 } else {  
  71.                     docWords[i] = termToIndexMap.get(word);  
  72.                     termCountMap.put(word, termCountMap.get(word) + 1);  
  73.                 }  
  74.             }  
  75.             words.clear();  
  76.         }  
  77.           
  78.         public boolean isNoiseWord(String string) {  
  79.             // TODO Auto-generated method stub  
  80.             string = string.toLowerCase().trim();  
  81.             Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");  
  82.             Matcher m = MY_PATTERN.matcher(string);  
  83.             // filter @xxx and URL  
  84.             if(string.matches(".*www\\..*") || string.matches(".*\\.com.*") ||   
  85.                     string.matches(".*http:.*") )  
  86.                 return true;  
  87.             if (!m.matches()) {  
  88.                 return true;  
  89.             } else  
  90.                 return false;  
  91.         }  
  92.           
  93.     }  
  94. }  

2 LDA Gibbs Sampling

文本预处理完毕后我们就可以实现LDA Gibbs Sampling。 首先我们要定义需要的参数,我的实现中在程序中给出了参数默认值,同时也支持配置文件覆盖,程序默认优先选用配置文件的参数设置。整个算法流程包括模型初始化,迭代Inference,不断更新主题和待估计参数,最后输出收敛时的参数估计结果。

包含主函数的配置参数解析类如下:

[java] view plaincopy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. import java.io.File;  
  4. import java.io.IOException;  
  5. import java.util.ArrayList;  
  6.   
  7. import liuyang.nlp.lda.com.FileUtil;  
  8. import liuyang.nlp.lda.conf.ConstantConfig;  
  9. import liuyang.nlp.lda.conf.PathConfig;  
  10.   
  11. /**Liu Yang's implementation of Gibbs Sampling of LDA 
  12.  * @author yangliu 
  13.  * @blog http://blog.csdn.net/yangliuy 
  14.  * @mail yangliuyx@gmail.com 
  15.  */  
  16.   
  17. public class LdaGibbsSampling {  
  18.       
  19.     public static class modelparameters {  
  20.         float alpha = 0.5f; //usual value is 50 / K  
  21.         float beta = 0.1f;//usual value is 0.1  
  22.         int topicNum = 100;  
  23.         int iteration = 100;  
  24.         int saveStep = 10;  
  25.         int beginSaveIters = 50;  
  26.     }  
  27.       
  28.     /**Get parameters from configuring file. If the  
  29.      * configuring file has value in it, use the value. 
  30.      * Else the default value in program will be used 
  31.      * @param ldaparameters 
  32.      * @param parameterFile 
  33.      * @return void 
  34.      */  
  35.     private static void getParametersFromFile(modelparameters ldaparameters,  
  36.             String parameterFile) {  
  37.         // TODO Auto-generated method stub  
  38.         ArrayList<String> paramLines = new ArrayList<String>();  
  39.         FileUtil.readLines(parameterFile, paramLines);  
  40.         for(String line : paramLines){  
  41.             String[] lineParts = line.split("\t");  
  42.             switch(parameters.valueOf(lineParts[0])){  
  43.             case alpha:  
  44.                 ldaparameters.alpha = Float.valueOf(lineParts[1]);  
  45.                 break;  
  46.             case beta:  
  47.                 ldaparameters.beta = Float.valueOf(lineParts[1]);  
  48.                 break;  
  49.             case topicNum:  
  50.                 ldaparameters.topicNum = Integer.valueOf(lineParts[1]);  
  51.                 break;  
  52.             case iteration:  
  53.                 ldaparameters.iteration = Integer.valueOf(lineParts[1]);  
  54.                 break;  
  55.             case saveStep:  
  56.                 ldaparameters.saveStep = Integer.valueOf(lineParts[1]);  
  57.                 break;  
  58.             case beginSaveIters:  
  59.                 ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);  
  60.                 break;  
  61.             }  
  62.         }  
  63.     }  
  64.       
  65.     public enum parameters{  
  66.         alpha, beta, topicNum, iteration, saveStep, beginSaveIters;  
  67.     }  
  68.       
  69.     /** 
  70.      * @param args 
  71.      * @throws IOException  
  72.      */  
  73.     public static void main(String[] args) throws IOException {  
  74.         // TODO Auto-generated method stub  
  75.         String originalDocsPath = PathConfig.ldaDocsPath;  
  76.         String resultPath = PathConfig.LdaResultsPath;  
  77.         String parameterFile= ConstantConfig.LDAPARAMETERFILE;  
  78.           
  79.         modelparameters ldaparameters = new modelparameters();  
  80.         getParametersFromFile(ldaparameters, parameterFile);  
  81.         Documents docSet = new Documents();  
  82.         docSet.readDocs(originalDocsPath);  
  83.         System.out.println("wordMap size " + docSet.termToIndexMap.size());  
  84.         FileUtil.mkdir(new File(resultPath));  
  85.         LdaModel model = new LdaModel(ldaparameters);  
  86.         System.out.println("1 Initialize the model ...");  
  87.         model.initializeModel(docSet);  
  88.         System.out.println("2 Learning and Saving the model ...");  
  89.         model.inferenceModel(docSet);  
  90.         System.out.println("3 Output the final model ...");  
  91.         model.saveIteratedModel(ldaparameters.iteration, docSet);  
  92.         System.out.println("Done!");  
  93.     }  
  94. }  

LDA 模型实现类如下

[java] view plaincopy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. /**Class for Lda model 
  4.  * @author yangliu 
  5.  * @blog http://blog.csdn.net/yangliuy 
  6.  * @mail yangliuyx@gmail.com 
  7.  */  
  8. import java.io.BufferedWriter;  
  9. import java.io.FileWriter;  
  10. import java.io.IOException;  
  11. import java.util.ArrayList;  
  12. import java.util.Collections;  
  13. import java.util.Comparator;  
  14. import java.util.List;  
  15.   
  16. import liuyang.nlp.lda.com.FileUtil;  
  17. import liuyang.nlp.lda.conf.PathConfig;  
  18.   
  19. public class LdaModel {  
  20.       
  21.     int [][] doc;//word index array  
  22.     int V, K, M;//vocabulary size, topic number, document number  
  23.     int [][] z;//topic label array  
  24.     float alpha; //doc-topic dirichlet prior parameter   
  25.     float beta; //topic-word dirichlet prior parameter  
  26.     int [][] nmk;//given document m, count times of topic k. M*K  
  27.     int [][] nkt;//given topic k, count times of term t. K*V  
  28.     int [] nmkSum;//Sum for each row in nmk  
  29.     int [] nktSum;//Sum for each row in nkt  
  30.     double [][] phi;//Parameters for topic-word distribution K*V  
  31.     double [][] theta;//Parameters for doc-topic distribution M*K  
  32.     int iterations;//Times of iterations  
  33.     int saveStep;//The number of iterations between two saving  
  34.     int beginSaveIters;//Begin save model at this iteration  
  35.       
  36.     public LdaModel(LdaGibbsSampling.modelparameters modelparam) {  
  37.         // TODO Auto-generated constructor stub  
  38.         alpha = modelparam.alpha;  
  39.         beta = modelparam.beta;  
  40.         iterations = modelparam.iteration;  
  41.         K = modelparam.topicNum;  
  42.         saveStep = modelparam.saveStep;  
  43.         beginSaveIters = modelparam.beginSaveIters;  
  44.     }  
  45.   
  46.     public void initializeModel(Documents docSet) {  
  47.         // TODO Auto-generated method stub  
  48.         M = docSet.docs.size();  
  49.         V = docSet.termToIndexMap.size();  
  50.         nmk = new int [M][K];  
  51.         nkt = new int[K][V];  
  52.         nmkSum = new int[M];  
  53.         nktSum = new int[K];  
  54.         phi = new double[K][V];  
  55.         theta = new double[M][K];  
  56.           
  57.         //initialize documents index array  
  58.         doc = new int[M][];  
  59.         for(int m = 0; m < M; m++){  
  60.             //Notice the limit of memory  
  61.             int N = docSet.docs.get(m).docWords.length;  
  62.             doc[m] = new int[N];  
  63.             for(int n = 0; n < N; n++){  
  64.                 doc[m][n] = docSet.docs.get(m).docWords[n];  
  65.             }  
  66.         }  
  67.           
  68.         //initialize topic lable z for each word  
  69.         z = new int[M][];  
  70.         for(int m = 0; m < M; m++){  
  71.             int N = docSet.docs.get(m).docWords.length;  
  72.             z[m] = new int[N];  
  73.             for(int n = 0; n < N; n++){  
  74.                 int initTopic = (int)(Math.random() * K);// From 0 to K - 1  
  75.                 z[m][n] = initTopic;  
  76.                 //number of words in doc m assigned to topic initTopic add 1  
  77.                 nmk[m][initTopic]++;  
  78.                 //number of terms doc[m][n] assigned to topic initTopic add 1  
  79.                 nkt[initTopic][doc[m][n]]++;  
  80.                 // total number of words assigned to topic initTopic add 1  
  81.                 nktSum[initTopic]++;  
  82.             }  
  83.              // total number of words in document m is N  
  84.             nmkSum[m] = N;  
  85.         }  
  86.     }  
  87.   
  88.     public void inferenceModel(Documents docSet) throws IOException {  
  89.         // TODO Auto-generated method stub  
  90.         if(iterations < saveStep + beginSaveIters){  
  91.             System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));  
  92.             System.exit(0);  
  93.         }  
  94.         for(int i = 0; i < iterations; i++){  
  95.             System.out.println("Iteration " + i);  
  96.             if((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)){  
  97.                 //Saving the model  
  98.                 System.out.println("Saving model at iteration " + i +" ... ");  
  99.                 //Firstly update parameters  
  100.                 updateEstimatedParameters();  
  101.                 //Secondly print model variables  
  102.                 saveIteratedModel(i, docSet);  
  103.             }  
  104.               
  105.             //Use Gibbs Sampling to update z[][]  
  106.             for(int m = 0; m < M; m++){  
  107.                 int N = docSet.docs.get(m).docWords.length;  
  108.                 for(int n = 0; n < N; n++){  
  109.                     // Sample from p(z_i|z_-i, w)  
  110.                     int newTopic = sampleTopicZ(m, n);  
  111.                     z[m][n] = newTopic;  
  112.                 }  
  113.             }  
  114.         }  
  115.     }  
  116.       
  117.     private void updateEstimatedParameters() {  
  118.         // TODO Auto-generated method stub  
  119.         for(int k = 0; k < K; k++){  
  120.             for(int t = 0; t < V; t++){  
  121.                 phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);  
  122.             }  
  123.         }  
  124.           
  125.         for(int m = 0; m < M; m++){  
  126.             for(int k = 0; k < K; k++){  
  127.                 theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  128.             }  
  129.         }  
  130.     }  
  131.   
  132.     private int sampleTopicZ(int m, int n) {  
  133.         // TODO Auto-generated method stub  
  134.         // Sample from p(z_i|z_-i, w) using Gibbs upde rule  
  135.           
  136.         //Remove topic label for w_{m,n}  
  137.         int oldTopic = z[m][n];  
  138.         nmk[m][oldTopic]--;  
  139.         nkt[oldTopic][doc[m][n]]--;  
  140.         nmkSum[m]--;  
  141.         nktSum[oldTopic]--;  
  142.           
  143.         //Compute p(z_i = k|z_-i, w)  
  144.         double [] p = new double[K];  
  145.         for(int k = 0; k < K; k++){  
  146.             p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  147.         }  
  148.           
  149.         //Sample a new topic label for w_{m, n} like roulette  
  150.         //Compute cumulated probability for p  
  151.         for(int k = 1; k < K; k++){  
  152.             p[k] += p[k - 1];  
  153.         }  
  154.         double u = Math.random() * p[K - 1]; //p[] is unnormalised  
  155.         int newTopic;  
  156.         for(newTopic = 0; newTopic < K; newTopic++){  
  157.             if(u < p[newTopic]){  
  158.                 break;  
  159.             }  
  160.         }  
  161.           
  162.         //Add new topic label for w_{m, n}  
  163.         nmk[m][newTopic]++;  
  164.         nkt[newTopic][doc[m][n]]++;  
  165.         nmkSum[m]++;  
  166.         nktSum[newTopic]++;  
  167.         return newTopic;  
  168.     }  
  169.   
  170.     public void saveIteratedModel(int iters, Documents docSet) throws IOException {  
  171.         // TODO Auto-generated method stub  
  172.         //lda.params lda.phi lda.theta lda.tassign lda.twords  
  173.         //lda.params  
  174.         String resPath = PathConfig.LdaResultsPath;  
  175.         String modelName = "lda_" + iters;  
  176.         ArrayList<String> lines = new ArrayList<String>();  
  177.         lines.add("alpha = " + alpha);  
  178.         lines.add("beta = " + beta);  
  179.         lines.add("topicNum = " + K);  
  180.         lines.add("docNum = " + M);  
  181.         lines.add("termNum = " + V);  
  182.         lines.add("iterations = " + iterations);  
  183.         lines.add("saveStep = " + saveStep);  
  184.         lines.add("beginSaveIters = " + beginSaveIters);  
  185.         FileUtil.writeLines(resPath + modelName + ".params", lines);  
  186.           
  187.         //lda.phi K*V  
  188.         BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));         
  189.         for (int i = 0; i < K; i++){  
  190.             for (int j = 0; j < V; j++){  
  191.                 writer.write(phi[i][j] + "\t");  
  192.             }  
  193.             writer.write("\n");  
  194.         }  
  195.         writer.close();  
  196.           
  197.         //lda.theta M*K  
  198.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));  
  199.         for(int i = 0; i < M; i++){  
  200.             for(int j = 0; j < K; j++){  
  201.                 writer.write(theta[i][j] + "\t");  
  202.             }  
  203.             writer.write("\n");  
  204.         }  
  205.         writer.close();  
  206.           
  207.         //lda.tassign  
  208.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));  
  209.         for(int m = 0; m < M; m++){  
  210.             for(int n = 0; n < doc[m].length; n++){  
  211.                 writer.write(doc[m][n] + ":" + z[m][n] + "\t");  
  212.             }  
  213.             writer.write("\n");  
  214.         }  
  215.         writer.close();  
  216.           
  217.         //lda.twords phi[][] K*V  
  218.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));  
  219.         int topNum = 20//Find the top 20 topic words in each topic  
  220.         for(int i = 0; i < K; i++){  
  221.             List<Integer> tWordsIndexArray = new ArrayList<Integer>();   
  222.             for(int j = 0; j < V; j++){  
  223.                 tWordsIndexArray.add(new Integer(j));  
  224.             }  
  225.             Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));  
  226.             writer.write("topic " + i + "\t:\t");  
  227.             for(int t = 0; t < topNum; t++){  
  228.                 writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");  
  229.             }  
  230.             writer.write("\n");  
  231.         }  
  232.         writer.close();  
  233.     }  
  234.       
  235.     public class TwordsComparable implements Comparator<Integer> {  
  236.           
  237.         public double [] sortProb; // Store probability of each word in topic k  
  238.           
  239.         public TwordsComparable (double[] sortProb){  
  240.             this.sortProb = sortProb;  
  241.         }  
  242.   
  243.         @Override  
  244.         public int compare(Integer o1, Integer o2) {  
  245.             // TODO Auto-generated method stub  
  246.             //Sort topic word index according to the probability of each word in topic k  
  247.             if(sortProb[o1] > sortProb[o2]) return -1;  
  248.             else if(sortProb[o1] < sortProb[o2]) return 1;  
  249.             else return 0;  
  250.         }  
  251.     }  
  252. }  

程序的实现细节可以参考我在程序中给出的注释,如果理解LDA Gibbs Sampling的算法流程,上面的代码很好理解。其实排除输入输出和参数解析的代码,标准LDA 的Gibbs sampling只需要不到200行程序就可以搞定。当然,里面有很多可以考虑优化和变形的地方。

还有com和conf目录下的源文件分别放置常用函数和配置类,完整的JAVA工程见Github https://github.com/yangliuy/LDAGibbsSampling


3 用LDA Gibbs Sampling对Newsgroup 18828文档集进行主题分析

下面我们给出将上面的LDA Gibbs Sampling的实现Apply到Newsgroup 18828文档集进行主题分析的结果。 我实验时用到的数据已经上传到Github中,感兴趣的朋友可以直接从Github中下载工程运行。 我在Newsgroup 18828文档集随机选择了9个目录,每个目录下选择一个文档,将它们放置在data\LdaOriginalDocs目录下,我设定的模型参数如下

[plain] view plaincopy
  1. alpha   0.5  
  2. beta    0.1  
  3. topicNum    10  
  4. iteration   100  
  5. saveStep    10  
  6. beginSaveIters  80  

即设定alpha和beta的值为0.5和0.1, Topic数目为10,迭代100次,从第80次开始保存模型结果,每10次保存一次。

经过100次Gibbs Sampling迭代后,程序输出10个Topic下top的topic words以及对应的概率值如下



我们可以看到虽然是unsupervised learning, LDA分析出来的Topic words还是非常make sense的。比如第5个topic是宗教类的,第6个topic是天文类的,第7个topic是计算机类的。程序的输出还包括模型参数.param文件,topic-word分布phi向量.phi文件,doc-topic分布theta向量.theta文件以及每个文档中每个单词分配到的主题label的.tassign文件。感兴趣的朋友可以从Github https://github.com/yangliuy/LDAGibbsSampling 下载完整工程自己换用其他数据集进行主题分析实验。 本程序是初步实现版本,如果大家发现任何问题或者bug欢迎交流,我第一时间在Github修复bug更新版本。


4 参考文献

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.
[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details Technical report, 2005.

[4] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[6] Jgibblda, http://jgibblda.sourceforge.net/

[7]David M. Blei, Andrew Y. Ng, and Michael I. Jordan. 2003. Latent dirichlet allocation. J. Mach. Learn. Res. 3 (March 2003), 993-1022.
0 0
原创粉丝点击