词向量源码解析:(3.5)GloVe源码解析之glove
来源:互联网 发布:js是哪里的车牌 编辑:程序博客网 时间:2024/06/05 18:43
和绝大多数的词向量不同,glove的目标是通过训练词向量和上下文向量,使得它们能够重构共现矩阵。glove训练部分的代码风格和word2vec中训练部分的代码风格如出一辙。有了之前看word2vec的基础,很容易就能看懂glove是怎么做的了。glove在三元组上面进行训练。三元组的数据结构依然和原来一样。
typedef struct cooccur_rec {
int word1;
int word2;
real val;
} CREC;
我们首先看看glove中要训练的参数,看initialize_parameters可以发现,跟word2vec中的参数基本一样,都是一份词向量参数和一份上下文向量参数。维度上glove多了一个bias,所以有一点区别,下面是代码。
void initialize_parameters() {
long long a, b;
vector_size++; // Temporarily increment to allocate space for bias//词向量多一维给bias
/* Allocate space for word vectors and context word vectors, and correspodning gradsq */
a = posix_memalign((void **)&W, 128, 2 * vocab_size * (vector_size + 1) * sizeof(real)); // Might perform better than malloc//词向量和上下文向量参数,这个乘2把两部分都包括了。这里vector_size加1,写的可能有问题,应该不需要加1
if (W == NULL) {
fprintf(stderr, "Error allocating memory for W\n");
exit(1);
}
a = posix_memalign((void **)&gradsq, 128, 2 * vocab_size * (vector_size + 1) * sizeof(real)); // Might perform better than malloc//glove用adagrad做梯度下降,还要为每个参数存梯度累积值。
if (gradsq == NULL) {
fprintf(stderr, "Error allocating memory for gradsq\n");
exit(1);
}
for (b = 0; b < vector_size; b++) for (a = 0; a < 2 * vocab_size; a++) W[a * vector_size + b] = (rand() / (real)RAND_MAX - 0.5) / vector_size;//初始化参数
for (b = 0; b < vector_size; b++) for (a = 0; a < 2 * vocab_size; a++) gradsq[a * vector_size + b] = 1.0; // So initial value of eta is equal to initial learning rate
vector_size--;
}
下面是train_glove函数,开启多线程调用glove_thread函数去训练,跟word2vec套路一样
/* Train model */
int train_glove() {
long long a, file_size;
int save_params_return_code;
int b;
FILE *fin;
real total_cost = 0;
fprintf(stderr, "TRAINING MODEL\n");
fin = fopen(input_file, "rb");//打开被打乱的三元组文件
if (fin == NULL) {fprintf(stderr,"Unable to open cooccurrence file %s.\n",input_file); return 1;}
fseeko(fin, 0, SEEK_END);
file_size = ftello(fin);
num_lines = file_size/(sizeof(CREC)); // Assuming the file isn't corrupt and consists only of CREC's//一共有多少个三元组
fclose(fin);
fprintf(stderr,"Read %lld lines.\n", num_lines);
if (verbose > 1) fprintf(stderr,"Initializing parameters...");
initialize_parameters();
if (verbose > 1) fprintf(stderr,"done.\n");
if (verbose > 0) fprintf(stderr,"vector size: %d\n", vector_size);
if (verbose > 0) fprintf(stderr,"vocab size: %lld\n", vocab_size);
if (verbose > 0) fprintf(stderr,"x_max: %lf\n", x_max);
if (verbose > 0) fprintf(stderr,"alpha: %lf\n", alpha);
pthread_t *pt = (pthread_t *)malloc(num_threads * sizeof(pthread_t));//多线程
lines_per_thread = (long long *) malloc(num_threads * sizeof(long long));//每个线程处理的三元组个数
time_t rawtime;
struct tm *info;
char time_buffer[80];
// Lock-free asynchronous SGD
for (b = 0; b < num_iter; b++) {//和word2vec不一样,轮数的循环写在了外面
total_cost = 0;
for (a = 0; a < num_threads - 1; a++) lines_per_thread[a] = num_lines / num_threads;//每个线程处理多少个三元组
lines_per_thread[a] = num_lines / num_threads + num_lines % num_threads;
long long *thread_ids = (long long*)malloc(sizeof(long long) * num_threads);//标识第几个线程
for (a = 0; a < num_threads; a++) thread_ids[a] = a;
for (a = 0; a < num_threads; a++) pthread_create(&pt[a], NULL, glove_thread, (void *)&thread_ids[a]);//开启多线程
for (a = 0; a < num_threads; a++) pthread_join(pt[a], NULL);//等待线程的结束
for (a = 0; a < num_threads; a++) total_cost += cost[a];
free(thread_ids);
time(&rawtime);
info = localtime(&rawtime);
strftime(time_buffer,80,"%x - %I:%M.%S%p", info);
fprintf(stderr, "%s, iter: %03d, cost: %lf\n", time_buffer, b+1, total_cost/num_lines);
if (checkpoint_every > 0 && (b + 1) % checkpoint_every == 0) {//glove还支持在每一轮都输出一个结果,这也是它把轮数循环写在外面的原因
fprintf(stderr," saving itermediate parameters for iter %03d...", b+1);
save_params_return_code = save_params(b+1);
if (save_params_return_code != 0)
return save_params_return_code;
fprintf(stderr,"done.\n");
}
}
free(pt);
free(lines_per_thread);
return save_params(0);
}
下面详细的说依稀glove_thread,这里面是glove训练的核心代码。
/* Train the GloVe model */
void *glove_thread(void *vid) {
long long a, b ,l1, l2;
long long id = *(long long*)vid;
CREC cr;
real diff, fdiff, temp1, temp2;
FILE *fin;
fin = fopen(input_file, "rb");
fseeko(fin, (num_lines / num_threads * id) * (sizeof(CREC)), SEEK_SET); //Threads spaced roughly equally throughout file//首先要找到这个线程从哪里开始训练三元组
cost[id] = 0;
real* W_updates1 = (real*)malloc(vector_size * sizeof(real));
real* W_updates2 = (real*)malloc(vector_size * sizeof(real));
for (a = 0; a < lines_per_thread[id]; a++) {//过一遍这个线程需要训练的所有三元组
fread(&cr, sizeof(CREC), 1, fin);//首先读取一个三元组
if (feof(fin)) break;
if (cr.word1 < 1 || cr.word2 < 1) { continue; }
/* Get location of words in W & gradsq */
l1 = (cr.word1 - 1LL) * (vector_size + 1); // cr word indices start at 1//找到这个三元组中两个单词的id
l2 = ((cr.word2 - 1LL) + vocab_size) * (vector_size + 1); // shift by vocab_size to get separate vectors for context words
/* Calculate cost, save diff for gradients */
diff = 0;
for (b = 0; b < vector_size; b++) diff += W[b + l1] * W[b + l2]; // dot product of word and context word vector//词向量和上下文向量的内积
diff += W[vector_size + l1] + W[vector_size + l2] - log(cr.val); // add separate bias for each word//再加上连个单词对应的bias,应该和两个单词共现的次数val的log值尽可能接近
fdiff = (cr.val > x_max) ? diff : pow(cr.val / x_max, alpha) * diff; // multiply weighting function (f) with diff//同时每个三元组的重要程度不一样,val高的三元组更重要一些
// Check for NaN and inf() in the diffs.
if (isnan(diff) || isnan(fdiff) || isinf(diff) || isinf(fdiff)) {
fprintf(stderr,"Caught NaN in diff for kdiff for thread. Skipping update");
continue;
}
cost[id] += 0.5 * fdiff * diff; // weighted squared error//平方误差
/* Adaptive gradient updates */
fdiff *= eta; // for ease in calculating gradient
real W_updates1_sum = 0;
real W_updates2_sum = 0;
for (b = 0; b < vector_size; b++) {//更新词向量和上下文向量的值
// learning rate times gradient for word vectors
temp1 = fdiff * W[b + l2];
temp2 = fdiff * W[b + l1];
// adaptive updates//adagrad梯度下降
W_updates1[b] = temp1 / sqrt(gradsq[b + l1]);//词向量要调整的值,adagrad公式
W_updates2[b] = temp2 / sqrt(gradsq[b + l2]);//上下文向量要调整的值
W_updates1_sum += W_updates1[b];
W_updates2_sum += W_updates2[b];
gradsq[b + l1] += temp1 * temp1;//更新梯度累积值,越来越大,使得learning rate越来越小,和word2vec的机制类似
gradsq[b + l2] += temp2 * temp2;
}
if (!isnan(W_updates1_sum) && !isinf(W_updates1_sum) && !isnan(W_updates2_sum) && !isinf(W_updates2_sum)) {
for (b = 0; b < vector_size; b++) {
W[b + l1] -= W_updates1[b];//更新参数值
W[b + l2] -= W_updates2[b];
}
}
// updates for bias terms//更新bias的值
W[vector_size + l1] -= check_nan(fdiff / sqrt(gradsq[vector_size + l1]));
W[vector_size + l2] -= check_nan(fdiff / sqrt(gradsq[vector_size + l2]));
fdiff *= fdiff;
gradsq[vector_size + l1] += fdiff;
gradsq[vector_size + l2] += fdiff;
}
free(W_updates1);
free(W_updates2);
fclose(fin);
pthread_exit(NULL);
}
glove的训练就完成了,和word2vec代码其实是高度的相似的,word2vec是对于每个word pair更新一次词向量和上下文向量,glove是根据每个三元组去更新一次词向量和上下文向量。最后介绍一个glove1如何保存模型得到的参数。glove中不仅可以保存词向量还可以保存上下文向量,以及可以保存每个单词的bias。代码比较长,不过逻辑很简单。
int save_params(int nb_iter) {
/*
* nb_iter is the number of iteration (= a full pass through the cooccurrence matrix).
* nb_iter > 0 => checkpointing the intermediate parameters, so nb_iter is in the filename of output file.
* else => saving the final paramters, so nb_iter is ignored.
*/
//目前内存中还没有词典,词典会把单词和id对应上
long long a, b;
char format[20];
char output_file[MAX_STRING_LENGTH], output_file_gsq[MAX_STRING_LENGTH];
char *word = malloc(sizeof(char) * MAX_STRING_LENGTH + 1);
FILE *fid, *fout, *fgs;
if (use_binary > 0) { // Save parameters in binary file//二进制存储
if (nb_iter <= 0)
sprintf(output_file,"%s.bin",save_W_file);
else
sprintf(output_file,"%s.%03d.bin",save_W_file,nb_iter);
fout = fopen(output_file,"wb");
if (fout == NULL) {fprintf(stderr, "Unable to open file %s.\n",save_W_file); return 1;}
for (a = 0; a < 2 * (long long)vocab_size * (vector_size + 1); a++) fwrite(&W[a], sizeof(real), 1,fout);//二进制的情况下就存储词向量,不需要读取单词
fclose(fout);
if (save_gradsq > 0) {//还可以存梯度的累积值
if (nb_iter <= 0)
sprintf(output_file_gsq,"%s.bin",save_gradsq_file);
else
sprintf(output_file_gsq,"%s.%03d.bin",save_gradsq_file,nb_iter);
fgs = fopen(output_file_gsq,"wb");
if (fgs == NULL) {fprintf(stderr, "Unable to open file %s.\n",save_gradsq_file); return 1;}
for (a = 0; a < 2 * (long long)vocab_size * (vector_size + 1); a++) fwrite(&gradsq[a], sizeof(real), 1,fgs);
fclose(fgs);
}
}
if (use_binary != 1) { // Save parameters in text file//非二进制的情况下,允许存储不同的参数
if (nb_iter <= 0)
sprintf(output_file,"%s.txt",save_W_file);
else
sprintf(output_file,"%s.%03d.txt",save_W_file,nb_iter);
if (save_gradsq > 0) {
if (nb_iter <= 0)
sprintf(output_file_gsq,"%s.txt",save_gradsq_file);
else
sprintf(output_file_gsq,"%s.%03d.txt",save_gradsq_file,nb_iter);
fgs = fopen(output_file_gsq,"wb");
if (fgs == NULL) {fprintf(stderr, "Unable to open file %s.\n",save_gradsq_file); return 1;}
}
fout = fopen(output_file,"wb");
if (fout == NULL) {fprintf(stderr, "Unable to open file %s.\n",save_W_file); return 1;}
fid = fopen(vocab_file, "r");//首先要读入字典
sprintf(format,"%%%ds",MAX_STRING_LENGTH);
if (fid == NULL) {fprintf(stderr, "Unable to open file %s.\n",vocab_file); return 1;}
if (write_header) fprintf(fout, "%ld %d\n", vocab_size, vector_size);//一般词向量写出的第一行都是词典中单词个数以及词向量的维度
for (a = 0; a < vocab_size; a++) {//遍历词典,读取一个单词,写一个单词,后面跟着写向量
if (fscanf(fid,format,word) == 0) return 1;
// input vocab cannot contain special <unk> keyword
if (strcmp(word, "<unk>") == 0) return 1;
fprintf(fout, "%s",word);//每行以单词开头
if (model == 0) { // Save all parameters (including bias)//存词向量上下文向量以及包括bias,和上面的写到文件第一行的vector_size对不上了。算是一个bug
for (b = 0; b < (vector_size + 1); b++) fprintf(fout," %lf", W[a * (vector_size + 1) + b]);
for (b = 0; b < (vector_size + 1); b++) fprintf(fout," %lf", W[(vocab_size + a) * (vector_size + 1) + b]);
}
if (model == 1) // Save only "word" vectors (without bias)//只存词向量没有bias
for (b = 0; b < vector_size; b++) fprintf(fout," %lf", W[a * (vector_size + 1) + b]);
if (model == 2) // Save "word + context word" vectors (without bias)//存词向量和上下文向量的求和
for (b = 0; b < vector_size; b++) fprintf(fout," %lf", W[a * (vector_size + 1) + b] + W[(vocab_size + a) * (vector_size + 1) + b]);
fprintf(fout,"\n");
if (save_gradsq > 0) { // Save gradsq
fprintf(fgs, "%s",word);
for (b = 0; b < (vector_size + 1); b++) fprintf(fgs," %lf", gradsq[a * (vector_size + 1) + b]);
for (b = 0; b < (vector_size + 1); b++) fprintf(fgs," %lf", gradsq[(vocab_size + a) * (vector_size + 1) + b]);
fprintf(fgs,"\n");
}
if (fscanf(fid,format,word) == 0) return 1; // Eat irrelevant frequency entry//文件词典中还有单词的频数,这里没用,读取以后不需要处理
}
if (use_unk_vec) {//后面是处理oov的逻辑,无足轻重。
real* unk_vec = (real*)calloc((vector_size + 1), sizeof(real));
real* unk_context = (real*)calloc((vector_size + 1), sizeof(real));
word = "<unk>";
int num_rare_words = vocab_size < 100 ? vocab_size : 100;
for (a = vocab_size - num_rare_words; a < vocab_size; a++) {
for (b = 0; b < (vector_size + 1); b++) {
unk_vec[b] += W[a * (vector_size + 1) + b] / num_rare_words;
unk_context[b] += W[(vocab_size + a) * (vector_size + 1) + b] / num_rare_words;
}
}
fprintf(fout, "%s",word);
if (model == 0) { // Save all parameters (including bias)
for (b = 0; b < (vector_size + 1); b++) fprintf(fout," %lf", unk_vec[b]);
for (b = 0; b < (vector_size + 1); b++) fprintf(fout," %lf", unk_context[b]);
}
if (model == 1) // Save only "word" vectors (without bias)
for (b = 0; b < vector_size; b++) fprintf(fout," %lf", unk_vec[b]);
if (model == 2) // Save "word + context word" vectors (without bias)
for (b = 0; b < vector_size; b++) fprintf(fout," %lf", unk_vec[b] + unk_context[b]);
fprintf(fout,"\n");
free(unk_vec);
free(unk_context);
}
fclose(fid);
fclose(fout);
if (save_gradsq > 0) fclose(fgs);
}
return 0;
}
- 词向量源码解析:(3.5)GloVe源码解析之glove
- 词向量源码解析:(3.1)GloVe源码解析
- 词向量源码解析:(3.6)GloVe源码解析小结
- 词向量源码解析:(3.2)GloVe源码解析之vocab_count
- 词向量源码解析:(3.3)GloVe源码解析之cooccur
- 词向量源码解析:(3.4)GloVe源码解析之shuffle
- GloVe 词向量模型
- 词向量之加载word2vec和glove
- GloVe 教程之实战入门+python gensim 词向量
- CS224n笔记三之词向量模型与GloVe
- 词向量源码解析:(6.6)fasttext源码解析之词向量1
- 词向量源码解析:(6.7)fasttext源码解析之词向量1
- 词向量源码解析:(2.1)word2vec源码解析
- 词向量源码解析:(2.7)word2vec源码解析小结
- 词向量源码解析:(4.1)hyperwords源码解析
- 词向量源码解析:(4.9)hyperwords源码解析小结
- 词向量源码解析:(5.1)ngram2vec源码解析
- 词向量源码解析:(5.12)ngram2vec源码解析小结
- longest-valid-parentheses
- Mysql与Oracle的区别
- node学习笔记<入门级>
- java多线程同步的五种方法
- bash shell 清空文件的方法
- 词向量源码解析:(3.5)GloVe源码解析之glove
- ActivityCallcycleCallbacks基本解析
- JAVA多线程实现的三种方式
- AndroidManifest.xml文件中属性记录
- Android权限记录
- 支付宝接口使用文档说明 支付宝异步通知(notify_url)与return_url.
- CoordinatorLayout父布局的Behavior
- 在SpringMVC中使用JSON需要导入的几个jar包
- Frameworks detected: Android framework is detected in the project