LDA理解以及源码分析(二)

来源:互联网 发布:蓝巨星软件 编辑:程序博客网 时间:2024/05/24 06:05

LDA系列的讲解分多个博文给出,主要大纲如下:

  • LDA相关的基础知识
    • 什么是共轭
    • multinomial分布
    • Dirichlet分布
  • LDA in text
    • LAD的概率图模型
    • LDA的参数推导
    • 伪代码
  • GibbsLDA++-0.2源码分析
  • Python实现GibbsLDA
  • 参考资料

GibbsLDA++-0.2源码分析

GibbsLDA++-0.2工具包下载地址为:下载

工具包里docs文件夹里有说明文件GibbsLDA++Manual.pdf,按照要求编译就可以使用,很方便。(具体使用方法后面给出)

代码在文件夹src中,主要有这么几个类:dataset,model, strtokenizer,utils以及lda.cpp, constants.h文件。

dataset

    //两个全局变量,分别存储word和id的对应。    typedef map<string, int> mapword2id;    typedef map<int, string> mapid2word;    //类document    class document {    public:        //保存每个word对应的id        int * words;         string rawstr;        //文章的words总数        int length;         document() {}        document(int length) {}        document(int length, int * words) {}        document(int length, int * words, string rawstr) {}        document(vector<int> & doc) {}        document(vector<int> & doc, string rawstr) {}        ~document() {}    };    class dataset {    public:        document ** docs;         document ** _docs; // used only for inference        map<int, int> _id2id; // also used only for inference        int M; // documents总数        int V; // words总数        dataset() {}        dataset(int M) {}           ~dataset() {}        void deallocate() {}        void add_doc(document * doc, int idx) {}           void _add_doc(document * doc, int idx) {}               //根据pword2id写wordmap,文件每行都是“word id”的格式        static int write_wordmap(string wordmapfile, mapword2id * pword2id);        //读wordmap中的内容,存储到pword2id中        static int read_wordmap(string wordmapfile, mapword2id * pword2id);        static int read_wordmap(string wordmapfile, mapid2word * pid2word);        int read_trndata(string dfile, string wordmapfile);//读训练文件        int read_newdata(string dfile, string wordmapfile);//读推断的新文件        int read_newdata_withrawstrs(string dfile, string wordmapfile);    };

utils

class utils {public:    // 解析命令行参数    static int parse_args(int argc, char ** argv, model * pmodel);    // 读<model_name>.others文件并解析模型参数    static int read_and_parse(string filename, model * model);     // 为当前的迭代生成模型名字,命令行有个参数会指定什么时候需要保存模型,iter=-1时最后才保存    static string generate_model_name(int iter);      // 排序    static void sort(vector<double> & probs, vector<int> & words);    static void quicksort(vector<pair<int, double> > & vect, int left, int right);};

strtokenizer

class strtokenizer {protected:    vector<string> tokens; //存储分后的词    int idx;//tokens的索引public:    strtokenizer(string str, string seperators = " ");//对str按照seperators分词,存在tokens中        void parse(string str, string seperators);//同上    int count_tokens();   //返回tokens的大小    string next_token();   //返回idx当前索引的token值    void start_scan();   //idx = 0    string token(int i);  // 返回tokens[i]的值};

model类是最重要的类,实现核心代码

    class model {    public:        // fixed options        string wordmapfile;     // file that contains word map [string -> integer id]        string trainlogfile;    // training log file        string tassign_suffix;  // suffix for topic assignment file        string theta_suffix;    // suffix for theta file        string phi_suffix;      // suffix for phi file        string others_suffix;   // suffix for file containing other parameters        string twords_suffix;   // suffix for file containing words-per-topics        string dir;         // model directory        string dfile;       // data file            string model_name;      // model name        int model_status;       // model status:本代码提供est,estc和inf三种模式        dataset * ptrndata; // 指针,指向训练文档集        dataset * pnewdata; // 指针,指向推断的新文档集        mapid2word id2word; // word map [int => string]          /*下面是模型参数和变量*/         int M; // documents数量        int V; // words数量(不重复)        int K; // topics数量        double alpha, beta; // LDA超参数         int niters; // Gibbs采样迭代的次数        int liter; // 需要将模型保存为文件的迭代次数        int savestep; // saving period        int twords; // 输出每个topic的前twords个词        int withrawstrs;        /*下面是LDA核心算法中的变量*/        double * p; // 采样的临时变量        int ** z; // M x doc.size(), 文档中words的topic分布        int ** nw; // V x K,nw[i][j]词i在主题j上出现的次数         int ** nd; // M x K,nd[i][j]: 文章i中属于主题j的词的个数         int * nwsum; // K, nwsum[j]: 属于主题j的词的数量        int * ndsum; // M, ndsum[i]: 文章i中词的数量         double ** theta; // M x K, document-topic分布        double ** phi; // K x V, topic-word分布        /*下面的变量只有在推断时用到*/        int inf_liter;        int newM;        int newV;        int ** newz;        int ** newnw;        int ** newnd;        int * newnwsum;        int * newndsum;        double ** newtheta;        double ** newphi;        // --------------------------------------        model() {                  ~model();          // 对变量设置初值,构造函数中调用        void set_default_values();           // 解析命令行得到LDA参数        int parse_args(int argc, char ** argv);         // 初始化模型,调用后面的init()设置初值        int init(int argc, char ** argv);         // 加载已有的LDA模型继续训练或者推断        int load_model(string model_name);        // 保存LDA模型文件        int save_model(string model_name);        int save_model_tassign(string filename);        int save_model_theta(string filename);        int save_model_phi(string filename);        int save_model_others(string filename);        int save_model_twords(string filename);        // 保存inf的结果        int save_inf_model(string model_name);        int save_inf_model_tassign(string filename);        int save_inf_model_newtheta(string filename);        int save_inf_model_newphi(string filename);        int save_inf_model_others(string filename);        int save_inf_model_twords(string filename);        // 训练模型前的初始化,init()函数中会调用        int init_est();        int init_estc();        // LDA核心算法!!!下面详细介绍        void estimate();        int sampling(int m, int n);        //计算迭代之后的theta和phi        void compute_theta();        void compute_phi();         // 推断初始化,在init()中调用        int init_inf();        // 在已知LDA模型中,对未知的新文档进行推断        void inference();        int inf_sampling(int m, int n);        void compute_newtheta();        void compute_newphi();    };

下面详细分析estimate()函数和sampling()函数

    void model::estimate() {    if (twords > 0) {//如果要求输出主题下的词,则读wordmap文件排序后输出        dataset::read_wordmap(dir + wordmapfile, &id2word);    }    printf("Sampling %d iterations!\n", niters);    int last_iter = liter;//记录迭代次数    for (liter = last_iter + 1; liter <= niters + last_iter; liter++) {        printf("Iteration %d ...\n", liter);        // 对所有的z[m][n]采样一个主题        for (int m = 0; m < M; m++) {            for (int n = 0; n < ptrndata->docs[m]->length; n++) {                // sampling根据Gibbs采样公式p(z_i|z_-i, w)采样                int topic = sampling(m, n);                z[m][n] = topic;            }        }        //判断是否需要保存这次迭代的模型        if (savestep > 0) {            if (liter % savestep == 0) {                printf("Saving the model at iteration %d ...\n", liter);                compute_theta();                compute_phi();                save_model(utils::generate_model_name(liter));            }        }    }    //迭代结束,根据theta和phi的公式计算值,并且保存模型文件    printf("Gibbs sampling completed!\n");    printf("Saving the final model!\n");    compute_theta();    compute_phi();    liter--;    save_model(utils::generate_model_name(-1));    }
    int model::sampling(int m, int n) {        //采样算法,排除当前词,根据其他词的主题分布计算当前词的分布        int topic = z[m][n];//获得当前主题        int w = ptrndata->docs[m]->words[n]; //获得当前词的id        //首先,排除当前词,即涉及的统计变量个数减一        nw[w][topic] -= 1;        nd[m][topic] -= 1;        nwsum[topic] -= 1;        ndsum[m] -= 1;        //临时计算变量        double Vbeta = V * beta;        double Kalpha = K * alpha;            /*通过累加来采样*/        //计算当前词在每个主题下的概率        for (int k = 0; k < K; k++) {            p[k] = (nw[w][k] + beta) / (nwsum[k] + Vbeta) *                (nd[m][k] + alpha) / (ndsum[m] + Kalpha);        }        // 概率和累加        for (int k = 1; k < K; k++) {            p[k] += p[k - 1];        }        //随机生成小数        double u = ((double)random() / RAND_MAX) * p[K - 1];        //根据随机生成的值,采样主题编号        for (topic = 0; topic < K; topic++) {            if (p[topic] > u) {                break;            }        }        //新的topic对应的参数增加1        nw[w][topic] += 1;        nd[m][topic] += 1;        nwsum[topic] += 1;        ndsum[m] += 1;              return topic;    }
1 0
原创粉丝点击