词向量源码解析:(6.4)fasttext源码解析之文本分类3
来源:互联网 发布:java string.replace 编辑:程序博客网 时间:2024/06/05 09:48
下面看model的内容,在model中更新参数,也就是输入向量和输出向量。
void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
assert(target >= 0);//首先确认标签是合法的
assert(target < osz_);
if (input.size() == 0) return;
computeHidden(input, hidden_);//对输入词的向量做平均得到hidden向量,和CBOW一样
if (args_->loss == loss_name::ns) {
loss_ += negativeSampling(target, lr);
} else if (args_->loss == loss_name::hs) {
loss_ += hierarchicalSoftmax(target, lr);
} else {//由于是文本分类,所以做softmax
loss_ += softmax(target, lr);//里面会对输出向量更新
}
nexamples_ += 1;
if (args_->model == model_name::sup) {//返回的要调整的值要除以文本的长度,这样每轮对词向量的调整非常小。大家可能疑惑为什么CBOW不这么做。实际上CBOW的调整策略并不严谨。但是因为上下文一般单词不多,所以对CBOW影响不大。当然CBOW也不是完全没道理,可以把CBOW看做是SG的特例。认为上下文中每个单词都是上下文中所有单词的平均
grad_.mul(1.0 / input.size());
}
for (auto it = input.cbegin(); it != input.cend(); ++it) {
wi_->addRow(grad_, *it, 1.0);//对词(输入)向量更新
}
}
computeHidden函数就是在求平均值
void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) const {
assert(hidden.size() == hsz_);
hidden.zero();//初始化0
for (auto it = input.cbegin(); it != input.cend(); ++it) {
if(quant_) {
hidden.addRow(*qwi_, *it);
} else {
hidden.addRow(*wi_, *it);//把词向量加到hidden上面
}
}
hidden.mul(1.0 / input.size());//平均数
}
计算softmax
real Model::softmax(int32_t target, real lr) {
grad_.zero();//
computeOutputSoftmax();//计算softmax
for (int32_t i = 0; i < osz_; i++) {//遍历所有输出向量
real label = (i == target) ? 1.0 : 0.0;
real alpha = lr * (label - output_[i]);//要更新的梯度
grad_.addRow(*wo_, i, alpha);//更新累积梯度,将来更新输入向量去
wo_->addRow(hidden_, i, alpha);//更新输出向量
}
return -log(output_[target]);//loss
}
下面看computeOutputSoftmax
void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
if (quant_ && args_->qout) {
output.mul(*qwo_, hidden);
} else {
output.mul(*wo_, hidden);//输出向量(多个输出向量是矩阵)和hidden向量做乘积
}
real max = output[0], z = 0.0;
for (int32_t i = 0; i < osz_; i++) {//softmax常规策略,减去最大值避免over/underflow
max = std::max(output[i], max);
}
for (int32_t i = 0; i < osz_; i++) {
output[i] = exp(output[i] - max);
z += output[i];//计算分母
}
for (int32_t i = 0; i < osz_; i++) {
output[i] /= z;//最终的softmax结果
}
}
void Model::computeOutputSoftmax() {
computeOutputSoftmax(hidden_, output_);
}
这样文本训练的任务就完成了。大家可能会好奇,ngram加在哪里了,实际上全在dictionary类中处理了。dictionary类直接返回line,line中包括了单词以及ngram的id。下面就看看dictionary类的内容。首先看单词的结构体,在dictionary.c中定义了。
struct entry {
std::string word;//字符串
int64_t count;//频数
entry_type type;//单词还是标签
std::vector<int32_t> subwords;//subwords的ids
};
再看看dictionary类构建单词。readWord和word2vec一样,从文件流中读取一个单词。
bool Dictionary::readWord(std::istream& in, std::string& word) const
{
char c;
std::streambuf& sb = *in.rdbuf();
word.clear();
while ((c = sb.sbumpc()) != EOF) {
if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
c == '\f' || c == '\0') {
if (word.empty()) {
if (c == '\n') {
word += EOS;
return true;
}
continue;
} else {
if (c == '\n')
sb.sungetc();
return true;
}
}
word.push_back(c);
}
// trigger eofbit
in.get();
return !word.empty();
}
readFromFile会从文件流中读取一个词典
void Dictionary::readFromFile(std::istream& in) {
std::string word;
int64_t minThreshold = 1;
while (readWord(in, word)) {
add(word);//向词典添加单词
if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
}
if (size_ > 0.75 * MAX_VOCAB_SIZE) {
minThreshold++;
threshold(minThreshold, minThreshold);//reduce单词,和word2vec的reduce一样
}
}
threshold(args_->minCount, args_->minCountLabel);//语料扫描完以后再去掉低频词
initTableDiscard();//用于subsampling
initNgrams();//用于得到单词的subword
if (args_->verbose > 0) {
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
std::cerr << "Number of words: " << nwords_ << std::endl;
std::cerr << "Number of labels: " << nlabels_ << std::endl;
}
if (size_ == 0) {
std::cerr << "Empty vocabulary. Try a smaller -minCount value."
<< std::endl;
exit(EXIT_FAILURE);
}
}
下面一个关键的函数就是getLine(),这个函数从数据流中读取一行然后返回这一行单词的id,对于有监督任务(文本分类),还要返回ngram的id。
int32_t Dictionary::getLine(std::istream& in,//输入是文件流
std::vector<int32_t>& words,//输出是的得到一行单词的id
std::vector<int32_t>& word_hashes,//得到一行单词的哈希值
std::vector<int32_t>& labels,//得到一行的label
std::minstd_rand& rng) const {
std::uniform_real_distribution<> uniform(0, 1);
if (in.eof()) {//文件结束了就从文件的第一行开始读
in.clear();
in.seekg(std::streampos(0));
}
words.clear();
labels.clear();
word_hashes.clear();
int32_t ntokens = 0;
std::string token;
while (readWord(in, token)) {//读一个单词
int32_t h = find(token);//得到单词的哈希值
int32_t wid = word2int_[h];//得到单词的id
if (wid < 0) {
entry_type type = getType(token);
if (type == entry_type::word) word_hashes.push_back(hash(token));
continue;
}
entry_type type = getType(wid);//确定是单词还是标签
ntokens++;
if (type == entry_type::word && !discard(wid, uniform(rng))) {//如果是单词且没有被subsampling掉
words.push_back(wid);//存入单词id
word_hashes.push_back(hash(token));//存入单词哈希
}
if (type == entry_type::label) {
labels.push_back(wid - nwords_);//保存label
}
if (token == EOS) break;//如果换行就跳出循环
if (ntokens > MAX_LINE_SIZE && args_->model != model_name::sup) break;
}
return ntokens;// 返回读取了多少的单词
}
int32_t Dictionary::getLine(std::istream& in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels,
std::minstd_rand& rng) const {
std::vector<int32_t> word_hashes;
int32_t ntokens = getLine(in, words, word_hashes, labels, rng);
if (args_->model == model_name::sup ) {//对于有监督任务还要加入ngram
addWordNgrams(words, word_hashes, args_->wordNgrams);//words中包括了单词和ngram的id
}
return ntokens;
}
- 词向量源码解析:(6.4)fasttext源码解析之文本分类3
- 词向量源码解析:(6.2)fasttext源码解析之文本分类1
- 词向量源码解析:(6.3)fasttext源码解析之文本分类2
- 词向量源码解析:(6.5)fasttext源码解析之文本分类4
- 词向量源码解析:(6.1)fasttext源码解析
- 词向量源码解析:(6.6)fasttext源码解析之词向量1
- 词向量源码解析:(6.7)fasttext源码解析之词向量1
- FastText 词向量与文本分类
- 词向量源码解析:(2.1)word2vec源码解析
- 词向量源码解析:(2.7)word2vec源码解析小结
- 词向量源码解析:(3.1)GloVe源码解析
- 词向量源码解析:(3.6)GloVe源码解析小结
- 词向量源码解析:(4.1)hyperwords源码解析
- 词向量源码解析:(4.9)hyperwords源码解析小结
- 词向量源码解析:(5.1)ngram2vec源码解析
- 词向量源码解析:(5.12)ngram2vec源码解析小结
- 词向量源码解析:(2.2)word2vec源码解析之word2phrase
- 词向量源码解析:(2.3)word2vec源码解析之word2vec
- 访问控制符
- Ubuntu 下安装Source Insight
- 读取配置中心更新后数据抛出异常
- Hive(上)--Hive介绍及部署
- 抽象工厂模式
- 词向量源码解析:(6.4)fasttext源码解析之文本分类3
- tcpdump介绍
- Xshell配置Xagent登陆服务器
- php数据导入到Excel中
- java中的==、equals()、hashCode()源码分析
- 【Outlook】2013 outlook working offline
- 变形课
- xpath语法总结
- 集合框架专题4—LinkedList