词向量源码解析:(3.4)GloVe源码解析之shuffle

来源:互联网 发布:计算机编程语言排名 编辑:程序博客网 时间:2024/05/16 11:08

这部分代码的功能是打乱共现矩阵中三元组的顺序。cooccur生成的三元组是排好序的。我没有尝试过用排好序的训练能得到什么结果。一个很简单的shuffle方法是把所有的都读入内存,在内存中打乱。但是在我们的内存不足以装下所有的三元组的时候怎么办?这里采用了一个两阶段的方法做shuffle,先局部shuffle,得到很多临时文件。字第二阶段,再均匀的从每个临时文件中读取三元组,放入内存。shuffle以后再写出的就是最终shuffle好的共现矩阵了。下面介绍几个关键函数

共现矩阵的三元组的数据结构,和之前一样

typedef struct cooccur_rec {
    int word1;
    int word2;
    real val;
} CREC;

array是三元组数组,size是数组长度,把内存中的array写到文件中,不需要像之前cooccur那样还汇总

/* Write contents of array to binary file */
int write_chunk(CREC *array, long size, FILE *fout) {
    long i = 0;
    for (i = 0; i < size; i++) fwrite(&array[i], sizeof(CREC), 1, fout);
    return 0;
}

对内存中的三元组进行打乱顺序shuffle,逻辑很简单,就是在交换顺序。

/* Fisher-Yates shuffle */
void shuffle(CREC *array, long n) {
    long i, j;
    CREC tmp;
    for (i = n - 1; i > 0; i--) {
        j = rand_long(i + 1);
        tmp = array[j];
        array[j] = array[i];
        array[i] = tmp;
    }
}

刚才说了shuffle分成两个阶段,第一个阶段是shuffle_by_chunks,顺序的读入文件中的三元组,局部排序

/* Shuffle large input stream by splitting into chunks */
int shuffle_by_chunks() {
    long i = 0, l = 0;
    int fidcounter = 0;
    char filename[MAX_STRING_LENGTH];
    CREC *array;
    FILE *fin = stdin, *fid;
    array = malloc(sizeof(CREC) * array_size);//存储三元组
    
    fprintf(stderr,"SHUFFLING COOCCURRENCES\n");
    if (verbose > 0) fprintf(stderr,"array size: %lld\n", array_size);
    sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
    fid = fopen(filename,"w");
    if (fid == NULL) {
        fprintf(stderr, "Unable to open file %s.\n",filename);
        return 1;
    }
    if (verbose > 1) fprintf(stderr, "Shuffling by chunks: processed 0 lines.");
    
    while (1) { //Continue until EOF//循环读入文件中的三元组
        if (i >= array_size) {// If array is full, shuffle it and save to temporary file//内存满了的话就写出
            shuffle(array, i-2);//打乱
            l += i;
            if (verbose > 1) fprintf(stderr, "\033[22Gprocessed %ld lines.", l);
            write_chunk(array,i,fid);//写出
            fclose(fid);
            fidcounter++;
            sprintf(filename,"%s_%04d.bin",file_head, fidcounter);//再向新的文件写入
            fid = fopen(filename,"w");
            if (fid == NULL) {
                fprintf(stderr, "Unable to open file %s.\n",filename);
                return 1;
            }
            i = 0;
        }
        fread(&array[i], sizeof(CREC), 1, fin);
        if (feof(fin)) break;
        i++;
    }
    shuffle(array, i-2); //Last chunk may be smaller than array_size
    write_chunk(array,i,fid);
    l += i;
    if (verbose > 1) fprintf(stderr, "\033[22Gprocessed %ld lines.\n", l);
    if (verbose > 1) fprintf(stderr, "Wrote %d temporary file(s).\n", fidcounter + 1);
    fclose(fid);
    free(array);
    return shuffle_merge(fidcounter + 1); // Merge and shuffle together temporary files//进入第二阶段
}

第二阶段再对之前的临时文件打乱一次,最后输出最终打乱的三元组,作为glove的输入

int shuffle_merge(int num) {//一共有num个临时文件
    long i, j, k, l = 0;
    int fidcounter = 0;
    CREC *array;
    char filename[MAX_STRING_LENGTH];
    FILE **fid, *fout = stdout;
    
    array = malloc(sizeof(CREC) * array_size);
    fid = malloc(sizeof(FILE) * num);
    for (fidcounter = 0; fidcounter < num; fidcounter++) { //num = number of temporary files to merge//打开所有的临时文件
        sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
        fid[fidcounter] = fopen(filename, "rb");
        if (fid[fidcounter] == NULL) {
            fprintf(stderr, "Unable to open file %s.\n",filename);
            return 1;
        }
    }
    if (verbose > 0) fprintf(stderr, "Merging temp files: processed %ld lines.", l);
    
    while (1) { //Loop until EOF in all files
        i = 0;
        //Read at most array_size values into array, roughly array_size/num from each temp file//从每个临时文件读入array_size/num个三元组,这样内存中会基本均匀的有所有临时文件的三元组,打乱他们再写出
        for (j = 0; j < num; j++) {
            if (feof(fid[j])) continue;
            for (k = 0; k < array_size / num; k++){//从每个文件读入固定数目的三元组,保证内存不会满
                fread(&array[i], sizeof(CREC), 1, fid[j]);
                if (feof(fid[j])) break;
                i++;
            }
        }
        if (i == 0) break;//如果读不出来了(也就是都读完了)就结束了
        l += i;
        shuffle(array, i-1); // Shuffles lines between temp files//打乱
        write_chunk(array,i,fout);//写出
        if (verbose > 0) fprintf(stderr, "\033[31G%ld lines.", l);
    }
    fprintf(stderr, "\033[0GMerging temp files: processed %ld lines.", l);
    for (fidcounter = 0; fidcounter < num; fidcounter++) {
        fclose(fid[fidcounter]);
        sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
        remove(filename);
    }
    fprintf(stderr, "\n\n");
    free(array);
    return 0;
}

阅读全文
0 0
原创粉丝点击