词向量源码解析:(3.3)GloVe源码解析之cooccur

来源:互联网 发布:淘宝怎么好评返现 编辑:程序博客网 时间:2024/06/05 15:57

cooccur代码读入语料和vocab_count建立的词典,输出共现矩阵。我们可以先想想这个任务应该怎么做。一个很直接简单的方法就是用二维数组去存。比如词典中有10000单词,那么我们初始化一个10000*10000的二维数组M[10000][10000]。其中第i行第j列就记录第i个单词和第j个单词在语料中共现的次数。当我们扫描语料,得到一个单词对儿的时候,就在相应的位置加1。这么看来建立共现矩阵的代码几行就搞定了。

这么做忽略了一个问题,就是共现矩阵太过稀疏,大多数的单词从来没有共现过,通过二维矩阵存非常浪费内存。实际上一般内存也根本存不下甚至中等规模大小的语料。所以用二维数据存储共现矩阵,在大规模语料上面是不可行的。

既然显示的存储共现矩阵不行,我们可以用稀疏矩阵去存储。扫描语料的过程中不断地更新稀疏矩阵,从而得到最终的共现矩阵。我这里说一种具体的方法:首先不断地从文件中读取单词对(pair),读到内存写满为止。对这些pair进行排序以及汇总,然后写出到文件。以此类推,我们就能得到很多文件,每个文件里面存的都是对部分的语料的共现矩阵。然后我们对这些共生矩阵再进行汇总。后面看代码的时候我们会看到具体的细节。这里也不推导空间复杂度了,一般2G,4G的内存就能在巨大的语料上生成共现矩阵了。

在GloVe代码中,它使用了混合的数据结构去存储共现矩阵。当我们的单词按照它们的频数排好序以后,比如the的id是1,of的id是2然后以此类推。我们会发现共现矩阵的左上角是非常稠密的。当一个pair中两个单词的id都很小(它们的频数很大),他们很有可能共现。而共现矩阵的其它地方很稀疏。比如右下角,低频词不大可能共现。基于这个观察,GloVe中使用二维矩阵去显式的存储共现矩阵的左上角,剩下的部分用我上段中提到的方法建立共现矩阵。所以,共现矩阵的左上角是一直在内存中的,剩余部分会写出最后再汇总。这里再说一下什么是左上角,GloVe中认为word pair中两个单词的id的乘积小于某个阈值时算是在左上角。阈值时100的时候,id为9和id为10的单词组成的pair就算在共现矩阵的左上角,因为9乘以10小于100.

知道了GloVe建立共现矩阵的基本流程,我们下面开始看代码。共现矩阵左上角以外的部分使用一个三元组去存储的。分别是第一个单词的id第二个单词的id以及他们的共现次数。GloVe对共现次数的计算是随着单词之间距离递减的,比如两个单词距离为5,那么算他们共现五分之一。所以共现次数是浮点型。建立共现矩阵自然也需要词典,所以内存中会维护词典。词典中记录单词的字符串和id,不再记录频数,单词的id是根据频数来的,频数越高id越小。

typedef struct cooccur_rec {//共现矩阵的元素用三元组存储
    int word1;
    int word2;
    real val;
} CREC;


typedef struct cooccur_rec_id {//在第二阶段对多个文件进行合并的时候,我们需要知道三元组来自于哪一个文件,id就是存储所在的文件的
    int word1;
    int word2;
    real val;
    int id;
} CRECID;


typedef struct hashrec {
    char *word;
    long long id;
    struct hashrec *next;
} HASHREC;

词典部分的代码和之前区别不大

HASHREC ** inithashtable() {
    int i;
    HASHREC **ht;
    ht = (HASHREC **) malloc( sizeof(HASHREC *) * TSIZE );
    for (i = 0; i < TSIZE; i++) ht[i] = (HASHREC *) NULL;
    return(ht);
}


/* Search hash table for given string, return record if found, else NULL */
HASHREC *hashsearch(HASHREC **ht, char *w) {//
    HASHREC *htmp, *hprv;
    unsigned int hval = HASHFN(w, TSIZE, SEED);//先得到单词的哈希值
    for (hprv = NULL, htmp=ht[hval]; htmp != NULL && scmp(htmp->word, w) != 0; hprv = htmp, htmp = htmp->next);//在链表中寻找单词
    if ( htmp != NULL && hprv!=NULL ) { // move to front on access//找到的话就把其移动到链表前端,提升效率
        hprv->next = htmp->next;
        htmp->next = ht[hval];
        ht[hval] = htmp;
    }
    return(htmp);
}


/* Insert string in hash table, check for duplicates which should be absent */
void hashinsert(HASHREC **ht, char *w, long long id) {
    HASHREC *htmp, *hprv;
    unsigned int hval = HASHFN(w, TSIZE, SEED);//先计算哈希值
    for (hprv = NULL, htmp = ht[hval]; htmp != NULL && scmp(htmp->word, w) != 0; hprv = htmp, htmp = htmp->next);//通过链表找到需要的单词
    if (htmp == NULL) {//没找到的话就插入
        htmp = (HASHREC *) malloc(sizeof(HASHREC));
        htmp->word = (char *) malloc(strlen(w) + 1);
        strcpy(htmp->word, w);
        htmp->id = id;
        htmp->next = NULL;
        if (hprv == NULL) ht[hval] = htmp;
        else hprv->next = htmp;
    }
    else fprintf(stderr, "Error, duplicate entry located: %s.\n",htmp->word);//插入得时候如果发现词典中有这个单词了,那么代表之前生成词典的时候出错了,出现了重复的单词。
    return;
}
从文件中读取一个单词,和word2vec中的ReadWord一样

/* Read word from input stream */
int get_word(char *word, FILE *fin) {
    int i = 0, ch;
    while (!feof(fin)) {
        ch = fgetc(fin);
        if (ch == 13) continue;
        if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
            if (i > 0) {
                if (ch == '\n') ungetc(ch, fin);
                break;
            }
            if (ch == '\n') return 1;
            else continue;
        }
        word[i++] = ch;
        if (i >= MAX_STRING_LENGTH - 1) i--;   // truncate words that exceed max length
    }
    word[i] = 0;
    return 0;
}

cr中存有三元组数组,要把它写到磁盘上,数组的大小是length,cr中的三元组都是排好序的,写出的时候要汇总一下

int write_chunk(CREC *cr, long long length, FILE *fout) {
    if (length == 0) return 0;


    long long a = 0;
    CREC old = cr[a];//记录着下一个写出的三元组
    
    for (a = 1; a < length; a++) {
        if (cr[a].word1 == old.word1 && cr[a].word2 == old.word2) {//如果顶端的三元组和要写出的word pair是一样的,那么先汇总,直到发现了word pair不一样的三元组,证明cr中包含这个word pair的三元组已经全部汇总了(cr是按照顺序排的)
            old.val += cr[a].val;
            continue;
        }
        fwrite(&old, sizeof(CREC), 1, fout);//如果word pair不一样,则写出之前的三元组
        old = cr[a];//把当前顶端的变成新的要写出的三元组
    }
    fwrite(&old, sizeof(CREC), 1, fout);//写出最后一个汇总好的三元组
    return 0;
}

对三元组排序的依据。先比第一个单词的id,再比第二个单词的id。id小(频数大)的排在前面,支持两种类型CREC和 CRECID

/* Check if two cooccurrence records are for the same two words, used for qsort */
int compare_crec(const void *a, const void *b) {
    int c;
    if ( (c = ((CREC *) a)->word1 - ((CREC *) b)->word1) != 0) return c;
    else return (((CREC *) a)->word2 - ((CREC *) b)->word2);
    
}


/* Check if two cooccurrence records are for the same two words */
int compare_crecid(CRECID a, CRECID b) {
    int c;
    if ( (c = a.word1 - b.word1) != 0) return c;
    else return a.word2 - b.word2;
}

下面说一下构建共现矩阵的核心代码,前面已经介绍过。我们用两种数据结构去存共现矩阵,显式的去存以及用三元组去存。下面的代码中会反映出这个逻辑。构建共现矩阵的过程又可以分成两个阶段,第一阶段是局部排序汇总。第二阶段是对局部排序汇总的文件进行合并,最终得到对整个语料排序汇总好的三元组,这也是最终的共现矩阵。下面的get_cooccurrence 函数对应着第一阶段。

/* Collect word-word cooccurrence counts from input stream */
int get_cooccurrence() {
    int flag, x, y, fidcounter = 1;
    long long a, j = 0, k, id, counter = 0, ind = 0, vocab_size, w1, w2, *lookup, *history;
    char format[20], filename[200], str[MAX_STRING_LENGTH + 1];
    FILE *fid, *foverflow;
    real *bigram_table, r;//bigram_table显式的存储左上角的共现矩阵,用一维数组模拟二维数组
    HASHREC *htmp, **vocab_hash = inithashtable();//词典
    CREC *cr = malloc(sizeof(CREC) * (overflow_length + 1));//我们能预估出大概内存中能存多少个三元组,满了以后就写出了
    history = malloc(sizeof(long long) * window_size);//存储上下文的单词id
    
    fprintf(stderr, "COUNTING COOCCURRENCES\n");
    if (verbose > 0) {
        fprintf(stderr, "window size: %d\n", window_size);
        if (symmetric == 0) fprintf(stderr, "context: asymmetric\n");
        else fprintf(stderr, "context: symmetric\n");
    }
    if (verbose > 1) fprintf(stderr, "max product: %lld\n", max_product);
    if (verbose > 1) fprintf(stderr, "overflow length: %lld\n", overflow_length);
    sprintf(format,"%%%ds %%lld", MAX_STRING_LENGTH); // Format to read from vocab file, which has (irrelevant) frequency data
    if (verbose > 1) fprintf(stderr, "Reading vocab from file \"%s\"...", vocab_file);
    fid = fopen(vocab_file,"r");//首先要把字典读入内存,得到每个单词的id
    if (fid == NULL) {fprintf(stderr,"Unable to open vocab file %s.\n",vocab_file); return 1;}
    while (fscanf(fid, format, str, &id) != EOF) hashinsert(vocab_hash, str, ++j); // Here id is not used: inserting vocab words into hash table with their frequency rank, j
    fclose(fid);
    vocab_size = j;
    j = 0;
    if (verbose > 1) fprintf(stderr, "loaded %lld words.\nBuilding lookup table...", vocab_size);
    
    /* Build auxiliary lookup table used to index into bigram_table */
    lookup = (long long *)calloc( vocab_size + 1, sizeof(long long) );//我们用一维数组模拟二维数组,根据之前讨论的左上角的定义,共现矩阵的每一行都有不同的元素的个数属于左上角,所以需要一个数组记录每一行前面几个元素,
    if (lookup == NULL) {
        fprintf(stderr, "Couldn't allocate memory!");
        return 1;
    }
    lookup[0] = 1;
    for (a = 1; a <= vocab_size; a++) {//根据左上角的定义构建共生矩阵
        if ((lookup[a] = max_product / a) < vocab_size) lookup[a] += lookup[a-1];
        else lookup[a] = lookup[a-1] + vocab_size;
    }
    if (verbose > 1) fprintf(stderr, "table contains %lld elements.\n",lookup[a-1]);
    
    /* Allocate memory for full array which will store all cooccurrence counts for words whose product of frequency ranks is less than max_product */
    bigram_table = (real *)calloc( lookup[a-1] , sizeof(real) );//左上角的共现矩阵
    if (bigram_table == NULL) {
        fprintf(stderr, "Couldn't allocate memory!");
        return 1;
    }
    
    fid = stdin;
    sprintf(format,"%%%ds",MAX_STRING_LENGTH);
    sprintf(filename,"%s_%04d.bin",file_head, fidcounter);//存左上角以外部分的文件
    foverflow = fopen(filename,"w");//存左上角共现矩阵的文件
    if (verbose > 1) fprintf(stderr,"Processing token: 0");
    
    /* For each token in input stream, calculate a weighted cooccurrence sum within window_size */
    while (1) {//对语料中的所有单词进行循环
        if (ind >= overflow_length - window_size) { // If overflow buffer is (almost) full, sort it and write it to temporary file//如果内存写满了
            qsort(cr, ind, sizeof(CREC), compare_crec);//首先排序
            write_chunk(cr,ind,foverflow);//汇总以后写出
            fclose(foverflow);
            fidcounter++;
            sprintf(filename,"%s_%04d.bin",file_head,fidcounter);//重新打开一个新的文件去写入
            foverflow = fopen(filename,"w");
            ind = 0;
        }
        flag = get_word(str, fid);//读一个单词进来
        if (feof(fid)) break;
        if (flag == 1) {j = 0; continue;} // Newline, reset line index (j)
        counter++;
        if ((counter%100000) == 0) if (verbose > 1) fprintf(stderr,"\033[19G%lld",counter);
        htmp = hashsearch(vocab_hash, str);//在词典中找单词的id
        if (htmp == NULL) continue; // Skip out-of-vocabulary words
        w2 = htmp->id; // Target word (frequency rank)//w2是当前单词的id
        for (k = j - 1; k >= ( (j > window_size) ? j - window_size : 0 ); k--) { // Iterate over all words to the left of target word, but not past beginning of line//j是当前单词在一行中的位置,这里找出当前词的上下文中的单词,和当前中心词配对,将来添加到共现矩阵中
            w1 = history[k % window_size]; // Context word (frequency rank)//上下文单词id
            if ( w1 < max_product/w2 ) { // Product is small enough to store in a full array//判断是不是在左上角
                bigram_table[lookup[w1-1] + w2 - 2] += 1.0/((real)(j-k)); // Weight by inverse of distance between words//显式的存共现矩阵,在数组上直接加就好
                if (symmetric > 0) bigram_table[lookup[w2-1] + w1 - 2] += 1.0/((real)(j-k)); // If symmetric context is used, exchange roles of w2 and w1 (ie look at right context too)//对上下文的定义,上下文只包括中心词左边的单词还是包括中心词两边的单词
            }
            else { // Product is too big, data is likely to be sparse. Store these entries in a temporary buffer to be sorted, merged (accumulated), and written to file when it gets full.//如果不在左上角,那么就按照三元组去存储
                cr[ind].word1 = w1;
                cr[ind].word2 = w2;
                cr[ind].val = 1.0/((real)(j-k));
                ind++; // Keep track of how full temporary buffer is
                if (symmetric > 0) { // Symmetric context
                    cr[ind].word1 = w2;
                    cr[ind].word2 = w1;
                    cr[ind].val = 1.0/((real)(j-k));
                    ind++;
                }
            }
        }
        history[j % window_size] = w2; // Target word is stored in circular buffer to become context word in the future//当前词变成了后面单词的上下文
        j++;
    }
    
    /* Write out temp buffer for the final time (it may not be full) *///全部语料扫过一遍,把还在内存中的cr数组的内容写到磁盘上面
    if (verbose > 1) fprintf(stderr,"\033[0GProcessed %lld tokens.\n",counter);
    qsort(cr, ind, sizeof(CREC), compare_crec);
    write_chunk(cr,ind,foverflow);
    sprintf(filename,"%s_0000.bin",file_head);
    
    /* Write out full bigram_table, skipping zeros *///把显式存储的共现矩阵写出来
    if (verbose > 1) fprintf(stderr, "Writing cooccurrences to disk");
    fid = fopen(filename,"w");
    j = 1e6;
    for (x = 1; x <= vocab_size; x++) {
        if ( (long long) (0.75*log(vocab_size / x)) < j) {j = (long long) (0.75*log(vocab_size / x)); if (verbose > 1) fprintf(stderr,".");} // log's to make it look (sort of) pretty
        for (y = 1; y <= (lookup[x] - lookup[x-1]); y++) {
            if ((r = bigram_table[lookup[x-1] - 2 + y]) != 0) {//也按照三元组写出,跳过共生矩阵元素为0的情况
                fwrite(&x, sizeof(int), 1, fid);
                fwrite(&y, sizeof(int), 1, fid);
                fwrite(&r, sizeof(real), 1, fid);
            }
        }
    }
    
    if (verbose > 1) fprintf(stderr,"%d files in total.\n",fidcounter + 1);
    fclose(fid);
    fclose(foverflow);
    free(cr);
    free(lookup);
    free(bigram_table);
    free(vocab_hash);
    return merge_files(fidcounter + 1); // Merge the sorted temporary files//第一阶段完成,进入第二阶段,对临时文件进行汇总
}

下面我们要汇总(merge)临时文件得到最终的共现矩阵。每个临时文件都是排好序汇总好的,所以每个文件顶端的三元组都是最小的。我们在内存中为每一个临时文件顶端的三元组留一个位置。这样最小的三元组一定在内存中。我们对内存中的三元组使用小顶堆这个数据结构,最小的元素一定在堆顶。这里就不细说这个数据结构了。和之前的write_chunk函数中的old变量一样,这里也有一个old变量存储下一个写出的三元组。当内存中的三元组都比要写出的三元组的大的时候,证明已经汇总好了,文件不会再有和old变量中一样的word pair了,这时就可以写出了。这里直接过一遍merge_files代码

/* Merge [num] sorted files of cooccurrence records */
int merge_files(int num) {//有num个临时文件,对这些文件进行汇总
    int i, size;
    long long counter = 0;
    CRECID *pq, new, old;//pq存所有临时文件的顶端三元组,old存下一个写出的三元组,当一个文件的顶端三元组被拿走,会用new变量临时存储文件的下一个三元组,然后存到pq中去
    char filename[200];
    FILE **fid, *fout;
    fid = malloc(sizeof(FILE) * num);//文件数组指针
    pq = malloc(sizeof(CRECID) * num);//存顶端三元组,小顶堆
    fout = stdout;
    if (verbose > 1) fprintf(stderr, "Merging cooccurrence files: processed 0 lines.");
    
    /* Open all files and add first entry of each to priority queue */
    for (i = 0; i < num; i++) {//首先打开所有的临时文件
        sprintf(filename,"%s_%04d.bin",file_head,i);
        fid[i] = fopen(filename,"rb");
        if (fid[i] == NULL) {fprintf(stderr, "Unable to open file %s.\n",filename); return 1;}
        fread(&new, sizeof(CREC), 1, fid[i]);//读取每个文件的顶端元素
        new.id = i;
        insert(pq,new,i+1);//插入到pq中,姐就是小顶堆中
    }
    
    /* Pop top node, save it in old to see if the next entry is a duplicate */
    size = num;//size存储了当前堆的元素个数,当一个文件读完之后,size减1,size为0就结束了
    old = pq[0];//堆顶的三元组是最小的,先放到old,是下一个要写出的
    i = pq[0].id;//记录堆顶的元素是哪个文件
    delete(pq, size);//删除堆顶元素
    fread(&new, sizeof(CREC), 1, fid[i]);//在把那个文件的下一个三元组读进来
    if (feof(fid[i])) size--;//如果文件中没有三元组了,size-1
    else {
        new.id = i;
        insert(pq, new, size);//把新读入的三元组插入到堆中
    }
    
    /* Repeatedly pop top node and fill priority queue until files have reached EOF */
    while (size > 0) {//进行循环,不断地从临时文件中读取三元组
        counter += merge_write(pq[0], &old, fout); // Only count the lines written to file, not duplicates//merge_write的逻辑和write_chunk写出逻辑一样,如果pq[0]和old的word pair一样就汇总,不一样就写出old,把pq[0]赋给old。
        if ((counter%100000) == 0) if (verbose > 1) fprintf(stderr,"\033[39G%lld lines.",counter);
        i = pq[0].id;//堆顶的三元组是从哪个文件读取的
        delete(pq, size);//删除堆顶,后面的逻辑和循环上面的几行代码一样
        fread(&new, sizeof(CREC), 1, fid[i]);
        if (feof(fid[i])) size--;
        else {
            new.id = i;
            insert(pq, new, size);
        }
    }
    fwrite(&old, sizeof(CREC), 1, fout);//写出old
    fprintf(stderr,"\033[0GMerging cooccurrence files: processed %lld lines.\n",++counter);
    for (i=0;i<num;i++) {//关闭所有的文件
        sprintf(filename,"%s_%04d.bin",file_head,i);
        remove(filename);
    }
    fprintf(stderr,"\n");
    return 0;
}

共现矩阵的构建就到此结束了。为了高效的共现矩阵,这篇文章还有很多的技巧,我们后面再说。


阅读全文
2 0
原创粉丝点击