词向量源码解析:(6.4)fasttext源码解析之文本分类3

来源:互联网 发布:java string.replace 编辑:程序博客网 时间:2024/06/05 09:48

下面看model的内容,在model中更新参数,也就是输入向量和输出向量。

void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
  assert(target >= 0);//首先确认标签是合法的
  assert(target < osz_);
  if (input.size() == 0) return;
  computeHidden(input, hidden_);//对输入词的向量做平均得到hidden向量,和CBOW一样
  if (args_->loss == loss_name::ns) {
    loss_ += negativeSampling(target, lr);
  } else if (args_->loss == loss_name::hs) {
    loss_ += hierarchicalSoftmax(target, lr);
  } else {//由于是文本分类,所以做softmax
    loss_ += softmax(target, lr);//里面会对输出向量更新
  }
  nexamples_ += 1;


  if (args_->model == model_name::sup) {//返回的要调整的值要除以文本的长度,这样每轮对词向量的调整非常小。大家可能疑惑为什么CBOW不这么做。实际上CBOW的调整策略并不严谨。但是因为上下文一般单词不多,所以对CBOW影响不大。当然CBOW也不是完全没道理,可以把CBOW看做是SG的特例。认为上下文中每个单词都是上下文中所有单词的平均
    grad_.mul(1.0 / input.size());
  }
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    wi_->addRow(grad_, *it, 1.0);//对词(输入)向量更新
  }
}

computeHidden函数就是在求平均值

void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) const {
  assert(hidden.size() == hsz_);
  hidden.zero();//初始化0
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    if(quant_) {
      hidden.addRow(*qwi_, *it);
    } else {
      hidden.addRow(*wi_, *it);//把词向量加到hidden上面
    }
  }
  hidden.mul(1.0 / input.size());//平均数
}

计算softmax

real Model::softmax(int32_t target, real lr) {
  grad_.zero();//
  computeOutputSoftmax();//计算softmax
  for (int32_t i = 0; i < osz_; i++) {//遍历所有输出向量
    real label = (i == target) ? 1.0 : 0.0;
    real alpha = lr * (label - output_[i]);//要更新的梯度
    grad_.addRow(*wo_, i, alpha);//更新累积梯度,将来更新输入向量去
    wo_->addRow(hidden_, i, alpha);//更新输出向量
  }
  return -log(output_[target]);//loss
}

下面看computeOutputSoftmax

void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
  if (quant_ && args_->qout) {
    output.mul(*qwo_, hidden);
  } else {
    output.mul(*wo_, hidden);//输出向量(多个输出向量是矩阵)和hidden向量做乘积
  }
  real max = output[0], z = 0.0;
  for (int32_t i = 0; i < osz_; i++) {//softmax常规策略,减去最大值避免over/underflow
    max = std::max(output[i], max);
  }
  for (int32_t i = 0; i < osz_; i++) {
    output[i] = exp(output[i] - max);
    z += output[i];//计算分母
  }
  for (int32_t i = 0; i < osz_; i++) {
    output[i] /= z;//最终的softmax结果
  }
}


void Model::computeOutputSoftmax() {
  computeOutputSoftmax(hidden_, output_);
}

这样文本训练的任务就完成了。大家可能会好奇,ngram加在哪里了,实际上全在dictionary类中处理了。dictionary类直接返回line,line中包括了单词以及ngram的id。下面就看看dictionary类的内容。首先看单词的结构体,在dictionary.c中定义了。

struct entry {
  std::string word;//字符串
  int64_t count;//频数
  entry_type type;//单词还是标签
  std::vector<int32_t> subwords;//subwords的ids
};

再看看dictionary类构建单词。readWord和word2vec一样,从文件流中读取一个单词。

bool Dictionary::readWord(std::istream& in, std::string& word) const
{
  char c;
  std::streambuf& sb = *in.rdbuf();
  word.clear();
  while ((c = sb.sbumpc()) != EOF) {
    if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
        c == '\f' || c == '\0') {
      if (word.empty()) {
        if (c == '\n') {
          word += EOS;
          return true;
        }
        continue;
      } else {
        if (c == '\n')
          sb.sungetc();
        return true;
      }
    }
    word.push_back(c);
  }
  // trigger eofbit
  in.get();
  return !word.empty();
}

readFromFile会从文件流中读取一个词典
void Dictionary::readFromFile(std::istream& in) {
  std::string word;
  int64_t minThreshold = 1;
  while (readWord(in, word)) {
    add(word);//向词典添加单词
    if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
      std::cerr << "\rRead " << ntokens_  / 1000000 << "M words" << std::flush;
    }
    if (size_ > 0.75 * MAX_VOCAB_SIZE) {
      minThreshold++;
      threshold(minThreshold, minThreshold);//reduce单词,和word2vec的reduce一样
    }
  }
  threshold(args_->minCount, args_->minCountLabel);//语料扫描完以后再去掉低频词
  initTableDiscard();//用于subsampling
  initNgrams();//用于得到单词的subword
  if (args_->verbose > 0) {
    std::cerr << "\rRead " << ntokens_  / 1000000 << "M words" << std::endl;
    std::cerr << "Number of words:  " << nwords_ << std::endl;
    std::cerr << "Number of labels: " << nlabels_ << std::endl;
  }
  if (size_ == 0) {
    std::cerr << "Empty vocabulary. Try a smaller -minCount value."
              << std::endl;
    exit(EXIT_FAILURE);
  }
}

下面一个关键的函数就是getLine(),这个函数从数据流中读取一行然后返回这一行单词的id,对于有监督任务(文本分类),还要返回ngram的id。

int32_t Dictionary::getLine(std::istream& in,//输入是文件流
                            std::vector<int32_t>& words,//输出是的得到一行单词的id
                            std::vector<int32_t>& word_hashes,//得到一行单词的哈希值
                            std::vector<int32_t>& labels,//得到一行的label
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1);


  if (in.eof()) {//文件结束了就从文件的第一行开始读
    in.clear();
    in.seekg(std::streampos(0));
  }


  words.clear();
  labels.clear();
  word_hashes.clear();
  int32_t ntokens = 0;
  std::string token;
  while (readWord(in, token)) {//读一个单词
    int32_t h = find(token);//得到单词的哈希值
    int32_t wid = word2int_[h];//得到单词的id
    if (wid < 0) {
      entry_type type = getType(token);
      if (type == entry_type::word) word_hashes.push_back(hash(token));
      continue;
    }
    entry_type type = getType(wid);//确定是单词还是标签
    ntokens++;
    if (type == entry_type::word && !discard(wid, uniform(rng))) {//如果是单词且没有被subsampling掉
      words.push_back(wid);//存入单词id
      word_hashes.push_back(hash(token));//存入单词哈希
    }
    if (type == entry_type::label) {
      labels.push_back(wid - nwords_);//保存label
    }
    if (token == EOS) break;//如果换行就跳出循环
    if (ntokens > MAX_LINE_SIZE && args_->model != model_name::sup) break;
  }
  return ntokens;// 返回读取了多少的单词
}




int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::vector<int32_t> word_hashes;
  int32_t ntokens = getLine(in, words, word_hashes, labels, rng);
  if (args_->model == model_name::sup ) {//对于有监督任务还要加入ngram
    addWordNgrams(words, word_hashes, args_->wordNgrams);//words中包括了单词和ngram的id
  }
  return ntokens;
}


阅读全文
0 0