词向量源码解析:(6.5)fasttext源码解析之文本分类4
来源:互联网 发布:cats vs dogs数据下载 编辑:程序博客网 时间:2024/05/22 15:27
下面说一下测试和预测,测试和预测分别调用了test和predict函数
int main(int argc, char** argv) {
  // A sub-command is mandatory: <binary> <command> [args...]
  if (argc < 2) {
    printUsage();
    exit(EXIT_FAILURE);
  }
  const std::string cmd(argv[1]);
  // Dispatch to the handler for the requested sub-command; each handler
  // receives the full argv so it can parse its own options.
  if (cmd == "skipgram" || cmd == "cbow" || cmd == "supervised") {
    // All three training modes share one entry point.
    train(argc, argv);
    return 0;
  }
  if (cmd == "test") {
    test(argc, argv);
    return 0;
  }
  if (cmd == "quantize") {
    quantize(argc, argv);
    return 0;
  }
  if (cmd == "print-word-vectors") {
    printWordVectors(argc, argv);
    return 0;
  }
  if (cmd == "print-sentence-vectors") {
    printSentenceVectors(argc, argv);
    return 0;
  }
  if (cmd == "print-ngrams") {
    printNgrams(argc, argv);
    return 0;
  }
  if (cmd == "nn") {
    nn(argc, argv);
    return 0;
  }
  if (cmd == "analogies") {
    analogies(argc, argv);
    return 0;
  }
  if (cmd == "predict" || cmd == "predict-prob") {
    predict(argc, argv);
    return 0;
  }
  // Unknown command: show usage and fail.
  printUsage();
  exit(EXIT_FAILURE);
}
下面是test函数,先实例化FastText并加载模型,然后在测试数据上进行评测
// CLI handler for "test": usage is `test <model> <test-data> [k]`.
// Loads the model, then evaluates precision/recall at k on the data.
void test(int argc, char** argv) {
  if (argc < 4 || argc > 5) {
    printTestUsage();
    exit(EXIT_FAILURE);
  }
  // k defaults to 1 (top-1 evaluation) when not given on the command line.
  const int32_t k = (argc >= 5) ? atoi(argv[4]) : 1;
  FastText fasttext;
  fasttext.loadModel(std::string(argv[2]));
  const std::string infile(argv[3]);
  if (infile == "-") {
    // "-" means read the evaluation data from standard input.
    fasttext.test(std::cin, k);
    exit(0);
  }
  std::ifstream ifs(infile);
  if (!ifs.is_open()) {
    std::cerr << "Test file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  fasttext.test(ifs, k);
  ifs.close();
  exit(0);
}
下面先看fasttext中的loadModel
// Open `filename` in binary mode, verify its header via checkModel,
// then delegate the actual deserialization to the stream overload.
// Exits the process on any failure.
void FastText::loadModel(const std::string& filename) {
  std::ifstream in(filename, std::ifstream::binary);
  if (!in.is_open()) {
    std::cerr << "Model file cannot be opened for loading!" << std::endl;
    exit(EXIT_FAILURE);
  }
  if (!checkModel(in)) {
    std::cerr << "Model file has wrong file format!" << std::endl;
    exit(EXIT_FAILURE);
  }
  loadModel(in);
  in.close();
}
// Deserialize a full model from `in`; the read order must mirror the order
// the model was saved in (args, dictionary, quantization flag, input
// matrix, qout flag, output matrix).
void FastText::loadModel(std::istream& in) {
// Fresh shared state; any previously loaded model is replaced.
args_ = std::make_shared<Args>();
dict_ = std::make_shared<Dictionary>(args_);
input_ = std::make_shared<Matrix>();
output_ = std::make_shared<Matrix>();
qinput_ = std::make_shared<QMatrix>();
qoutput_ = std::make_shared<QMatrix>();
args_->load(in);
dict_->load(in);
// One raw bool recorded in the file says whether the input matrix
// was stored quantized.
bool quant_input;
in.read((char*) &quant_input, sizeof(bool));
if (quant_input) {
quant_ = true;
qinput_->load(in);
} else {
input_->load(in);
}
// The qout flag is read straight into the args struct.
in.read((char*) &args_->qout, sizeof(bool));
// A quantized output matrix is only used when the input is quantized too.
if (quant_ && args_->qout) {
qoutput_->load(in);
} else {
output_->load(in);
}
// Build the model over the raw matrices, then wire in the quantized
// matrices separately via setQuantizePointer.
model_ = std::make_shared<Model>(input_, output_, args_, 0);
model_->quant_ = quant_;
model_->setQuantizePointer(qinput_, qoutput_, args_->qout);
// Supervised models score labels; unsupervised ones score words —
// choose the matching frequency counts for the model's target table.
if (args_->model == model_name::sup) {
model_->setTargetCounts(dict_->getCounts(entry_type::label));
} else {
model_->setTargetCounts(dict_->getCounts(entry_type::word));
}
}
下面看test函数,在这个例子中k是1
void FastText::test(std::istream& in, int32_t k) {
int32_t nexamples = 0, nlabels = 0;
double precision = 0.0;
std::vector<int32_t> line, labels;
while (in.peek() != EOF) {//遍历每一行
dict_->getLine(in, line, labels, model_->rng);//读取一行
if (labels.size() > 0 && line.size() > 0) {
std::vector<std::pair<real, int32_t>> modelPredictions;//考虑到了多标签
model_->predict(line, k, modelPredictions);//这里k等于1,就一个标签
for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
precision += 1.0;
}
}
nexamples++;
nlabels += labels.size();
}
}
std::cout << "N" << "\t" << nexamples << std::endl;
std::cout << std::setprecision(3);
std::cout << "P@" << k << "\t" << precision / (k * nexamples) << std::endl;
std::cout << "R@" << k << "\t" << precision / nlabels << std::endl;
std::cerr << "Number of examples: " << nexamples << std::endl;
}
再看关键代码predict函数
// Score `input` and leave the k best (log-prob, label) pairs in `heap`,
// sorted best-first. `hidden` and `output` are caller-supplied scratch.
void Model::predict(const std::vector<int32_t>& input, int32_t k,
                    std::vector<std::pair<real, int32_t>>& heap,
                    Vector& hidden, Vector& output) const {
  assert(k > 0);
  heap.reserve(k + 1);
  // Average the input word vectors into the hidden representation.
  computeHidden(input, hidden);
  if (args_->loss != loss_name::hs) {
    // Flat softmax: scan every output score and keep the k best.
    findKBest(k, heap, hidden, output);
  } else {
    // Hierarchical softmax: depth-first search from the tree root.
    dfs(k, 2 * osz_ - 2, 0.0, heap, hidden);
  }
  // Turn the heap into a list ordered from best to worst.
  std::sort_heap(heap.begin(), heap.end(), comparePairs);
}
// Convenience overload: same as the 5-argument predict, but reuses the
// model's own scratch buffers (hidden_ / output_) instead of requiring
// the caller to supply them.
void Model::predict(const std::vector<int32_t>& input, int32_t k,
std::vector<std::pair<real, int32_t>>& heap) {
predict(input, k, heap, hidden_, output_);  // hidden_/output_ are Model members used as scratch space
}
最后再看看findKBest
void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
Vector& hidden, Vector& output) const {
computeOutputSoftmax(hidden, output);//softmax的运算结果
for (int32_t i = 0; i < osz_; i++) {//遍历得到的softmax值,选取k个最大的
if (heap.size() == k && log(output[i]) < heap.front().first) {
continue;
}
heap.push_back(std::make_pair(log(output[i]), i));//如果大于第k个值就更新堆
std::push_heap(heap.begin(), heap.end(), comparePairs);
if (heap.size() > k) {
std::pop_heap(heap.begin(), heap.end(), comparePairs);
heap.pop_back();
}
}
}
- 词向量源码解析:(6.5)fasttext源码解析之文本分类4
- 词向量源码解析:(6.2)fasttext源码解析之文本分类1
- 词向量源码解析:(6.3)fasttext源码解析之文本分类2
- 词向量源码解析:(6.4)fasttext源码解析之文本分类3
- 词向量源码解析:(6.1)fasttext源码解析
- 词向量源码解析:(6.6)fasttext源码解析之词向量1
- 词向量源码解析:(6.7)fasttext源码解析之词向量2
- FastText 词向量与文本分类
- 词向量源码解析:(2.1)word2vec源码解析
- 词向量源码解析:(2.7)word2vec源码解析小结
- 词向量源码解析:(3.1)GloVe源码解析
- 词向量源码解析:(3.6)GloVe源码解析小结
- 词向量源码解析:(4.1)hyperwords源码解析
- 词向量源码解析:(4.9)hyperwords源码解析小结
- 词向量源码解析:(5.1)ngram2vec源码解析
- 词向量源码解析:(5.12)ngram2vec源码解析小结
- 词向量源码解析:(2.2)word2vec源码解析之word2phrase
- 词向量源码解析:(2.3)word2vec源码解析之word2vec
- 队列篇(二)----环形队列的应用(C++版)
- Solr Filter过滤器介绍
- 《红楼梦》的庭院叙事观点
- android studio 设置项目编码
- 装配Bean——通过XML装配bean
- 词向量源码解析:(6.5)fasttext源码解析之文本分类4
- 洛谷 P2962 高斯消元解异或方程
- window下注册服务的命令
- Hisat2下载
- BZOJ 2142 礼物 (扩展Lucas)
- Java后台框架篇--Spring Integration模块集成
- XPath+第二节
- Eclipse下如何打开Servers窗口及应用
- 小游戏