一个偷偷写的svm库
来源:互联网 发布:linux怎样启动tomcat 编辑:程序博客网 时间:2024/06/05 14:16
今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code.也就是说,DLIB是一个C++库,用于开发可移植的应用程序与网络处理,线程,图形界面,数据结构,线性代数,机器学习,XML和文本解析,数值优化,贝叶斯网,和许多其他任务。几乎涉及到数据分析的方方面面了。更重要的是,类似于openCV,它提供很多很详细的example,因此学习起来应该不难。由于今天第一天接触,就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该不太完善。总之,先放上来吧,以后再慢慢改,目前是涉及到nu和C参数的调整,默认是对gamma和C调参,因为这两个对结果影响最大嘛。
#pragma once

#include <QtCore>
#include <iostream>
#include <dlib/svm.h>
#include "dlib/rand/rand_kernel_abstract.h"

using namespace std;
using namespace dlib;

// Binary SVM classifier built on dlib (RBF kernel, nu- or C-SVM).
// NOTE: adjust the nFeatures macro to your feature count before use.
namespace SVM
{
#define nFeatures 2

typedef matrix<double, nFeatures, 1> sample_type;                               // one feature vector
typedef radial_basis_kernel<sample_type> kernel_type;                           // RBF kernel
typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;  // decision fn with probability output
typedef normalized_function<probabilistic_funct_type> pfunct_type;              // normalizer + probabilistic decision fn

enum Trainer  { CTrainer = 1, NUTrainer = 2 };
enum LoadType { LoadSamples = 1, LoadTestData = 2 };

class SVMClassification
{
public:
    SVMClassification()
    {
        samples.clear();
        labels.clear();
    }

    ~SVMClassification() {}

    // Load a CSV file whose first line is a header and whose rows are:
    //   id, feature_1 ... feature_nFeatures [, label]   (label only for LoadSamples)
    // opt == LoadSamples  -> fills samples/labels (label column: 1 -> +1.0, else -1.0)
    // opt == LoadTestData -> fills testData
    // Returns false if the file is missing or cannot be opened.
    bool loadData(const char* fn, int opt = LoadSamples)
    {
        if (!QFile::exists(fn))
        {
            cout << fn << " does not exist!\n";
            return false;
        }
        QFile infile(fn);
        if (!infile.open(QIODevice::ReadOnly))
        {
            cout << fn << " open error!\n";
            return false;
        }
        QTextStream _in(&infile);
        QString smsg = _in.readLine();  // skip the header line
        QStringList slist;
        if (opt == LoadSamples)
        {
            samples.clear();
            labels.clear();
        }
        else
            testData.clear();
        while (!_in.atEnd())
        {
            sample_type samp;
            smsg = _in.readLine();
            slist = smsg.split(",");
            for (int i = 0; i < nFeatures; i++)
            {
                // column 0 is an id, so features start at column 1
                samp(i) = slist[i + 1].trimmed().toDouble();
            }
            if (opt == LoadSamples)
            {
                samples.push_back(samp);
                // last column holds the class label: 1 -> +1.0, anything else -> -1.0
                labels.push_back(slist[slist.size() - 1].trimmed().toInt() == 1 ? 1.0 : -1.0);
            }
            else
                testData.push_back(samp);
        }
        infile.close();
        return true;
    }

    // Append `num` random samples: Gaussian features, random +1/-1 labels.
    bool generateRandomSamples(int num)
    {
        dlib::rand Rd;
        for (int i = 0; i < num; i++)
        {
            sample_type samp;
            for (int j = 0; j < nFeatures; j++)
                samp(j) = Rd.get_random_gaussian();
            samples.push_back(samp);
            double _label = (double)(Rd.get_random_16bit_number() % 2);
            if (_label == 0)
                _label = -1;
            labels.push_back(_label);
        }
        cout << "randomly generated " << num << " samples\n";
        return true;
    }

    // Normalize all loaded samples in place and shuffle them for cross validation.
    bool normalization()
    {
        normalizer.train(samples);  // learn mean and standard deviation from the samples
        for (unsigned long i = 0; i < samples.size(); ++i)
            samples[i] = normalizer(samples[i]);
        // shuffle samples so repeated cross validation folds are unbiased
        randomize_samples(samples, labels);
        return true;
    }

    // Grid-search the best kernel/trainer parameters by 10-fold cross validation.
    // opt == NUTrainer tunes (gamma, nu); opt == CTrainer tunes (gamma, C).
    // Results are stored in best_gamma and best_nu / best_c.
    bool findBestParam(int opt = CTrainer)
    {
        // nu must stay below this bound, which depends on the label ratio
        const double max_nu = maximum_nu(labels);
        cout << "max_nu = " << max_nu << endl;
        cout << "doing cross validation..." << endl;
        matrix<double> best_result(1, 2);
        best_result = 0;
        best_gamma = 0.0001, best_nu = 0.0001, best_c = 5;
        switch (opt)
        {
        case NUTrainer:
            for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
            {
                for (double nu = 0.00001; nu < max_nu; nu *= 5)
                {
                    trainer.set_kernel(kernel_type(gamma));
                    trainer.set_nu(nu);
                    cout << "gamma: " << gamma << " nu: " << nu;
                    matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
                    cout << " cross validation accuracy: " << result;
                    // keep the pair with the best combined (pos + neg) accuracy
                    if (sum(result) > sum(best_result))
                    {
                        best_result = result;
                        best_gamma = gamma;
                        best_nu = nu;
                    }
                }
            }
            cout << "\nbest gamma: " << best_gamma << " best nu: " << best_nu
                 << " best score: " << best_result << "mean acc: " << mean(best_result) << endl;
            break;
        case CTrainer:
            for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
            {
                for (double _c = 1; _c < 2000; _c *= 2)
                {
                    c_trainer.set_kernel(kernel_type(gamma));
                    c_trainer.set_c(_c);
                    cout << "gamma: " << gamma << " C: " << _c;
                    matrix<double> result = cross_validate_trainer(c_trainer, samples, labels, 10);
                    cout << " cross validation accuracy: " << result;
                    if (sum(result) > sum(best_result))
                    {
                        best_result = result;
                        best_gamma = gamma;
                        best_c = _c;
                    }
                }
            }
            cout << "\nbest gamma: " << best_gamma << " best c: " << best_c
                 << " best score: " << best_result << "mean acc: " << mean(best_result) << endl;
            break;
        }
        return true;
    }

    // Manual parameter setters (alternative to findBestParam).
    void setGamma(double _gamma) { best_gamma = _gamma; }
    void setNu(double _nu) { best_nu = _nu; }
    void setC(double _c) { best_c = _c; }

    // Append one sample (array of nFeatures doubles) with its label.
    bool addSample(double* _pData, double _label)
    {
        sample_type samp;
        for (int i = 0; i < nFeatures; i++)
        {
            samp(i) = _pData[i];
        }
        samples.push_back(samp);
        labels.push_back(_label);
        return true;
    }

    // Clear the samples and labels and reset from a 2-D array;
    // each row holds nFeatures features followed by the label.
    bool setSamples(double** _ppSamples, int _nsamples)
    {
        samples.clear();
        labels.clear();
        for (int i = 0; i < _nsamples; i++)
        {
            sample_type samp;
            for (int j = 0; j < nFeatures; j++)
            {
                samp(j) = _ppSamples[i][j];
            }
            samples.push_back(samp);
            labels.push_back(_ppSamples[i][nFeatures]);
        }
        return true;
    }

    // Train the probabilistic classifier with the current best parameters.
    bool learnFunc(int opt = CTrainer)
    {
        switch (opt)
        {
        case CTrainer:
            c_trainer.set_kernel(kernel_type(best_gamma));
            // BUGFIX: the original passed best_nu here; the C-SVM trainer takes
            // the C parameter produced by findBestParam / setC.
            c_trainer.set_c(best_c);
            learned_pfunct.normalizer = normalizer;
            learned_pfunct.function =
                train_probabilistic_decision_function(c_trainer, samples, labels, 3);
            break;
        case NUTrainer:
            trainer.set_kernel(kernel_type(best_gamma));
            trainer.set_nu(best_nu);
            learned_pfunct.normalizer = normalizer;
            learned_pfunct.function =
                train_probabilistic_decision_function(trainer, samples, labels, 3);
            break;
        }
        cout << "\nnumber of support vectors in our learned_pfunct is "
             << learned_pfunct.function.decision_funct.basis_vectors.size() << endl;
        return true;
    }

    // Predict the probability that a sample belongs to the +1 class.
    double predictProbability(sample_type _samp)
    {
        return learned_pfunct(_samp);
    }

    // Same, from a raw array of nFeatures doubles.
    double predictProbability(double* _val)
    {
        sample_type samp;
        for (int i = 0; i < nFeatures; i++)
            samp(i) = _val[i];
        return learned_pfunct(samp);
    }

    // Serialize the trained classifier to a file.
    bool saveLearnedFunc(const char* fn)
    {
        serialize(fn) << learned_pfunct;
        cout << "saved learned function to " << fn << endl;
        return true;
    }

    // Deserialize a previously trained classifier from a file.
    bool loadLearnedFunc(const char* fn)
    {
        deserialize(fn) >> learned_pfunct;
        cout << "loaded learned function from " << fn << endl;
        return true;
    }

    // Report 3-fold cross validation accuracy when the trainer is reduced to
    // at most nVectors support vectors (dlib reduced2 wrapper).
    bool getAccByCrossValidateWithVectors(int nVectors, int opt = CTrainer)
    {
        cout << "\ncross validation accuracy with only " << nVectors << " support vectors: ";
        switch (opt)
        {
        case CTrainer:
            cout << cross_validate_trainer(reduced2(c_trainer, nVectors), samples, labels, 3);
            break;
        case NUTrainer:
            cout << cross_validate_trainer(reduced2(trainer, nVectors), samples, labels, 3);
            break;
        }
        return true;
    }

private:
    std::vector<sample_type> samples;         // training feature vectors
    std::vector<double> labels;               // +1.0 / -1.0 labels, parallel to samples
    std::vector<sample_type> testData;        // feature vectors loaded with LoadTestData
    svm_nu_trainer<kernel_type> trainer;      // nu-SVM trainer
    svm_c_trainer<kernel_type> c_trainer;     // C-SVM trainer
    vector_normalizer<sample_type> normalizer;
    double best_gamma;
    double best_nu;
    double best_c;
    pfunct_type learned_pfunct;               // normalizer + trained probabilistic decision fn

protected:
};
}
0 0
- 一个偷偷写的svm库
- 日本人写的一个SVM,很有启发性
- 一个偷偷修改工作目录的幕后黑手
- 偷偷成熟的表现
- 一个关于SVM的说明
- SVM分类的一个例子
- SVM对手写数字的识别
- 【svm学习】使用svm的一个常见错误
- 偷偷的关闭IE7窗口
- 偷偷的收藏 c++技巧
- 偷偷想过的东西
- 礼拜四log~js一个偷偷技巧
- 使用svm的一个常见错误
- 使用svm的一个常见错误
- 使用svm的一个常见错误
- SVM分类的一个具体例子
- LeftNotEasy写的理解SVM的博文(上)
- LeftNotEasy写的理解SVM的博文(上)
- Intel Edison用mjpg-streamer进行视频传输
- Children of the Candy Corn(图的遍历bfs最小步数)
- @contextmanager:Python实现with结构的好方法
- ajax
- C++中int转化为string
- 一个偷偷写的svm库
- Python version 2.7 required, which was not found in the registry
- Machine Learning --5种距离度量方法
- MySQL5.7.13更改密码时出现ERROR 1054 (42S22): Unknown column 'password' in 'field list'
- jquery eval解析JSON中的注意点介绍
- 洛谷 P1070 道路游戏
- 数据结构-线性表
- php-beast加密php源码
- 93. Restore IP Addresses