Notes on LR for Multiclass Classification
1. From a probabilistic point of view, the posterior probability of a sample is inferred as (PRML eq. 4.62-4.63):

    p(Ck|x) = p(x|Ck)p(Ck) / Σj p(x|Cj)p(Cj) = exp(ak) / Σj exp(aj)

where:

    ak = ln p(x|Ck)p(Ck)

Eq. 4.63 can take a fairly concise form, e.g. a linear expression.
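One practical detail not in the original note: computing exp(ak) directly can overflow for large activations (the training code below does exactly that); the usual fix is to subtract the largest activation before exponentiating, which leaves the softmax unchanged. A minimal C++ sketch (assumes a non-empty activation vector):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: subtracting the max activation before
// exponentiating leaves the result unchanged but avoids exp() overflow.
std::vector<double> Softmax(const std::vector<double>& a) {
  const double max_a = *std::max_element(a.begin(), a.end());
  std::vector<double> p(a.size());
  double total = 0.0;
  for (size_t k = 0; k < a.size(); ++k) {
    p[k] = std::exp(a[k] - max_a);
    total += p[k];
  }
  for (size_t k = 0; k < a.size(); ++k) {
    p[k] /= total;
  }
  return p;
}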
2. Assume p(x|Ck) is Gaussian, with a covariance Σ shared across classes.
Then ln p(x|Ck)p(Ck) reduces to an expression that is linear in x:

    ak(x) = wk'x + wk0,  where  wk = Σ^{-1}μk  and  wk0 = -1/2 μk'Σ^{-1}μk + ln p(Ck)
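A step the note leaves implicit (my addition, following PRML 4.68-4.70): expand the Gaussian log-density and observe that the quadratic term does not depend on the class, so it cancels in the softmax:

\ln p(x|\mathcal{C}_k)p(\mathcal{C}_k)
  = -\tfrac{1}{2}x^{\mathsf T}\Sigma^{-1}x
    + \mu_k^{\mathsf T}\Sigma^{-1}x
    - \tfrac{1}{2}\mu_k^{\mathsf T}\Sigma^{-1}\mu_k
    + \ln p(\mathcal{C}_k) + \mathrm{const}

Only the last three k-dependent terms survive, which gives the linear form ak(x) above. If each class had its own Σk, the quadratic term would not cancel and the decision boundaries would be quadratic instead of linear.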
3. Solving for the model parameters:
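The parameters are fit by gradient descent on the cross-entropy error with an L2 penalty; the gradient is the standard result (PRML eq. 4.108-4.109), and this update is exactly what TrainInternal in the code below implements:

E(w) = -\sum_n \sum_k t_{nk} \ln y_{nk}, \qquad y_{nk} = p(\mathcal{C}_k \mid x_n)

\nabla_{w_k} E = \sum_n \left( y_{nk} - t_{nk} \right) x_n

w_k \leftarrow w_k - \alpha \left( \nabla_{w_k} E + \lambda\, w_k \right)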
4. Fundamentally, the probabilistic route to classification is the more trustworthy one: intuitively, if points cluster densely in some region, the probability of the corresponding class is high there. With a squared-error criterion, the distance does not actually measure the distance to a class.
SVC does use a similar notion of distance, but SVC keeps only the support vectors and projects into a high-dimensional space.
5. sample code:
#include <string>
#include <vector>
#include <cmath>
#include <map>

#include "base/flags.h"
#include "base/string_util.h"
#include "utils/hash_tables.h"
#include "common/file/simple_line_reader.h"

DEFINE_string(train_path, "./test.txt", "training file");
DEFINE_double(lambda, 0.001, "the weight of regularization");
DEFINE_double(alpha, 0.1, "the learning rate");
DEFINE_int32(n, 5000, "iteration times");

// Original data sample: a label plus sparse named features.
struct DataSample {
  std::string label;
  double predict_prob;
  utils::hash_map<std::string, double> features;
  void AddFeature(const std::string& fn, const double& v) { features[fn] = v; }
};

// Inner class labels are [0, 1, ..., label_count_ - 1].
class TrainningDataSet {
 public:
  TrainningDataSet() {
    label_count_ = 0;
    inner_feature_map_["cb"] = 0;  // "cb" = constant bias, always dimension 0
    outer_feature_map_[0] = "cb";
    feature_count_ = 1;
  }

  // Line format: "A 1:0.2 2:0.4 3:77 4:0.3"
  bool LoadSamplesFromFile(const std::string& file_path) {
    file::SimpleLineReader line_reader;
    line_reader.OpenOrDie(file_path);
    std::vector<std::string> lines;
    line_reader.ReadLines(&lines);
    for (size_t i = 0; i < lines.size(); ++i) {
      std::vector<std::string> parts;
      SplitString(lines[i], ' ', &parts);
      if (parts.empty()) {
        continue;
      }
      DataSample sample;
      sample.label = parts[0];
      AddLabel(parts[0]);
      for (size_t j = 1; j < parts.size(); ++j) {
        std::vector<std::string> fn_v;
        SplitString(parts[j], ':', &fn_v);
        if (fn_v.size() != 2) {
          continue;
        }
        double v = 0.0;
        StringToDouble(fn_v[1], &v);
        sample.AddFeature(fn_v[0], v);
        AddFeature(fn_v[0]);
      }
      samples_.push_back(sample);
    }
    return true;
  }

  void Train() {
    AllocAuxParam();
    TrainInternal(FLAGS_n);
    Predict();
    FreeAuxParam();
  }

  // For every sample, report the predicted probability of its true label.
  void Predict() {
    std::vector<double> prob(label_count_);
    for (auto it = samples_.begin(); it != samples_.end(); ++it) {
      for (int k = 0; k < label_count_; ++k) {
        prob[k] = w[k][0];  // bias input x[0] = 1
        for (auto sit = it->features.begin(); sit != it->features.end(); ++sit) {
          prob[k] += sit->second * w[k][inner_feature_map_[sit->first]];
        }
        prob[k] = exp(prob[k]);
      }
      double total_exp = 0.0;
      for (int k = 0; k < label_count_; ++k) {
        total_exp += prob[k];
      }
      it->predict_prob = prob[inner_label_map_[it->label]] / total_exp;
      VLOG(0) << "sample predict [" << it->label << "]: " << it->predict_prob;
    }
  }

  void TrainInternal(int32 count) {
    for (int32 i = 0; i < count; ++i) {
      // Step 1: the posterior P(Ck | xn) via softmax over W*x.
      for (size_t n = 0; n < samples_.size(); ++n) {
        const DataSample& sample = samples_[n];
        for (int32 k = 0; k < label_count_; ++k) {
          post_prob[n][k] = w[k][0];  // bias input x[0] = 1
          for (auto it = sample.features.begin(); it != sample.features.end(); ++it) {
            int32 feature_idx = inner_feature_map_[it->first];
            post_prob[n][k] += it->second * w[k][feature_idx];
          }
          post_prob[n][k] = exp(post_prob[n][k]);
        }
        double exp_total = 0.0;
        for (int32 k = 0; k < label_count_; ++k) {
          exp_total += post_prob[n][k];
        }
        for (int32 k = 0; k < label_count_; ++k) {
          post_prob[n][k] /= exp_total;
        }
      }
      // Step 2: gradient dE/dWk = sum_n (Ynk - Tnk)*Xn, then the update.
      for (int32 k = 0; k < label_count_; ++k) {
        for (int32 d = 0; d < feature_count_; ++d) {
          grad[k][d] = 0.0;
        }
        for (size_t n = 0; n < samples_.size(); ++n) {
          double Tnk = GetTnk(n, k);
          double Ynk = post_prob[n][k];
          for (int32 d = 0; d < feature_count_; ++d) {
            grad[k][d] += GetXnd(n, d) * (Ynk - Tnk);
          }
        }
        for (int32 d = 0; d < feature_count_; ++d) {
          grad[k][d] += w[k][d] * FLAGS_lambda;  // L2 regularization
          w[k][d] -= FLAGS_alpha * grad[k][d];
        }
      }
    }
  }

  void Dump() {
    utils::hash_map<std::string, int>::iterator it;
    for (it = inner_label_map_.begin(); it != inner_label_map_.end(); ++it) {
      VLOG(0) << "label: " << it->first << ", " << it->second;
    }
    for (it = inner_feature_map_.begin(); it != inner_feature_map_.end(); ++it) {
      VLOG(0) << "feature: " << it->first << ", " << it->second;
    }
  }

 private:
  double** w;          // w[k][d]; update: w[k] -= alpha * grad[k]
  double** grad;       // grad[k][d] = sum_n (Ynk - Tnk)*Xn + lambda*w[k]
  double** post_prob;  // post_prob[n][k] = P(Ck | xn)

  std::vector<DataSample> samples_;

  utils::hash_map<std::string, int> inner_label_map_;  // "A" -> 0, "B" -> 1
  utils::hash_map<int, std::string> outer_label_map_;  // 0 -> "A", 1 -> "B"
  int32 label_count_;

  int AddLabel(const std::string& outer_label) {
    utils::hash_map<std::string, int>::iterator it = inner_label_map_.find(outer_label);
    if (it == inner_label_map_.end()) {
      // New label: assign the next inner index.
      inner_label_map_[outer_label] = label_count_;
      outer_label_map_[label_count_] = outer_label;
      return label_count_++;
    }
    return it->second;
  }

  utils::hash_map<std::string, int> inner_feature_map_;  // "cb" -> 0, "1" -> 1, "2" -> 2, ...
  utils::hash_map<int, std::string> outer_feature_map_;
  int32 feature_count_;

  void AddFeature(const std::string& feature_name) {
    utils::hash_map<std::string, int>::iterator it = inner_feature_map_.find(feature_name);
    if (it == inner_feature_map_.end()) {
      inner_feature_map_[feature_name] = feature_count_;
      outer_feature_map_[feature_count_] = feature_name;
      feature_count_++;
    }
  }

  // Xnd: value of dimension d for sample n; d == 0 is the constant bias 1.
  double GetXnd(const int32& n, const int32& d) {
    if (d == 0) {
      return 1.0;
    }
    DataSample& sample = samples_[n];
    const std::string& ol = outer_feature_map_[d];
    auto it = sample.features.find(ol);
    return it != sample.features.end() ? it->second : 0.0;
  }

  // Tnk: 1-of-K target; 1.0 iff sample n belongs to class k.
  double GetTnk(const int32& n, const int32& k) {
    return inner_label_map_[samples_[n].label] == k ? 1.0 : 0.0;
  }

  void FreeAuxParam() {
    for (int k = 0; k < label_count_; ++k) {
      delete[] w[k];
      delete[] grad[k];
    }
    delete[] w;
    delete[] grad;
    for (size_t n = 0; n < samples_.size(); ++n) {
      delete[] post_prob[n];
    }
    delete[] post_prob;
  }

  void AllocAuxParam() {
    // w[k][d], grad[k][d]
    w = new double*[label_count_];
    grad = new double*[label_count_];
    for (int k = 0; k < label_count_; ++k) {
      w[k] = new double[feature_count_];
      grad[k] = new double[feature_count_];
      for (int f = 0; f < feature_count_; ++f) {
        w[k][f] = 0.0;
      }
    }
    // post_prob[n][k]
    post_prob = new double*[samples_.size()];
    for (size_t n = 0; n < samples_.size(); ++n) {
      post_prob[n] = new double[label_count_];
    }
  }
};

int main(int argc, char* argv[]) {
  base::ParseCommandLineFlags(&argc, &argv, false);
  TrainningDataSet tds;
  tds.LoadSamplesFromFile(FLAGS_train_path);
  tds.Dump();
  tds.Train();
  return 0;
}
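To run it (assuming the binary is built as lr; the flag names come from the DEFINE_* declarations at the top and are parsed gflags-style):

    ./lr --train_path=./test.txt --alpha=0.1 --lambda=0.001 --n=5000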
Test data:
A 1:0.20 2:0.70
A 1:0.10 2:0.80
A 1:0.30 2:0.60
A 1:0.05 2:0.94
A 1:0.77 2:0.22
A 1:0.44 2:0.55
B 1:0.20 2:0.81
B 1:0.30 2:0.71
B 1:1.00 2:0.01
B 1:0.50 2:0.51
B 1:0.40 2:0.65
B 1:0.70 2:0.40
Results:
I0930 09:53:14.443656 32046 lr.cc:100] sample predict [A]: 0.926048
I0930 09:53:14.443763 32046 lr.cc:100] sample predict [A]: 0.932346
I0930 09:53:14.443836 32046 lr.cc:100] sample predict [A]: 0.919215
I0930 09:53:14.443896 32046 lr.cc:100] sample predict [A]: 0.592285
I0930 09:53:14.443940 32046 lr.cc:100] sample predict [A]: 0.421599
I0930 09:53:14.443987 32046 lr.cc:100] sample predict [A]: 0.499967
I0930 09:53:14.444035 32046 lr.cc:100] sample predict [B]: 0.569759
I0930 09:53:14.444111 32046 lr.cc:100] sample predict [B]: 0.593065
I0930 09:53:14.444164 32046 lr.cc:100] sample predict [B]: 0.740223
I0930 09:53:14.444211 32046 lr.cc:100] sample predict [B]: 0.638351
I0930 09:53:14.444258 32046 lr.cc:100] sample predict [B]: 0.816627
I0930 09:53:14.444300 32046 lr.cc:100] sample predict [B]: 0.955107
Several of the points in the middle lie very close to one another (near the class boundary), so their predicted probabilities are not that high.