一个偷偷写的svm库

来源:互联网 发布:linux怎样启动tomcat 编辑:程序博客网 时间:2024/06/05 14:16

今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code. 也就是说,dlib是一个C++库,用于开发可移植的应用程序,涵盖网络处理,线程,图形界面,数据结构,线性代数,机器学习,XML和文本解析,数值优化,贝叶斯网,和许多其他任务。几乎涉及到数据分析的方方面面了。更重要的是,类似于OpenCV,它提供很多很详细的example,因此学习起来应该不难。由于今天第一天接触,就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该不太完善。总之,先放上来吧,以后再慢慢改,目前是涉及到nu和C参数的调整,默认是对gamma和C调参,因为这两个对结果影响最大嘛。


#include <QtCore>
#include <iostream>
#include <dlib/svm.h>
#include "dlib/rand/rand_kernel_abstract.h"

using namespace std;
using namespace dlib;

// Binary (two-class) SVM classifier built on dlib.
// Change nFeatures to the dimensionality of your data before using it.
namespace SVM
{
#define nFeatures 2

typedef matrix<double, nFeatures, 1> sample_type;        // one sample: nFeatures x 1 column vector
typedef radial_basis_kernel<sample_type> kernel_type;    // RBF kernel over sample_type
typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;
typedef normalized_function<probabilistic_funct_type> pfunct_type;

// Selects which of the two member trainers an operation uses.
enum Trainer { CTrainer = 1, NUTrainer = 2 };
// Selects the destination container filled by loadData().
enum LoadType { LoadSamples = 1, LoadTestData = 2 };

class SVMClassification
{
public:
    // The hyper-parameters start at the same values findBestParam() uses as
    // its initial "best", so learnFunc() never reads an indeterminate value
    // when it is called before findBestParam() or the setters.
    SVMClassification()
        : best_gamma(0.0001), best_nu(0.0001), best_c(5)
    {
        samples.clear();
        labels.clear();
    }

    ~SVMClassification() {}

    // Loads comma-separated data from the file `fn`.
    //
    // The first line is read and discarded (presumably a header row).
    // For every following line the features are taken from columns
    // 1..nFeatures; column 0 is not used (presumably an id/index column --
    // confirm against the data files).
    //
    // opt == LoadSamples : replaces `samples`/`labels`; the label comes from
    //                      the LAST column: integer 1 -> +1.0, anything
    //                      else -> -1.0.
    // any other value    : replaces `testData`; no label is read.
    //
    // Blank lines are ignored and lines with fewer than nFeatures + 1 fields
    // are skipped (and counted) instead of being indexed out of range.
    //
    // Returns false if the file does not exist or cannot be opened,
    // true otherwise.
    bool loadData(const char* fn, int opt = LoadSamples)
    {
        if (!QFile::exists(fn))
        {
            cout << fn << "does not exist!\n";
            return false;
        }

        QFile infile(fn);
        if (!infile.open(QIODevice::ReadOnly))
        {
            cout << fn << "open error!\n";
            return false;
        }

        QTextStream _in(&infile);
        QString smsg = _in.readLine();   // first line is dropped
        QStringList slist;

        if (opt == LoadSamples)
        {
            samples.clear();
            labels.clear();
        }
        else
        {
            testData.clear();
        }

        int skipped = 0;
        while (!_in.atEnd())
        {
            smsg = _in.readLine();
            if (smsg.trimmed().isEmpty())
                continue;                // e.g. trailing newline at end of file

            slist = smsg.split(",");
            if (slist.size() < nFeatures + 1)
            {
                // slist[i + 1] below would read past the end of the list.
                ++skipped;
                continue;
            }

            sample_type samp;
            for (int i = 0; i < nFeatures; i++)
                samp(i) = slist[i + 1].trimmed().toDouble();

            if (opt == LoadSamples)
            {
                samples.push_back(samp);
                labels.push_back(slist[slist.size() - 1].trimmed().toInt() == 1 ? 1.0 : -1.0);
            }
            else
            {
                testData.push_back(samp);
            }
        }
        infile.close();

        if (skipped > 0)
            cout << fn << ": skipped " << skipped << " malformed line(s)\n";
        return true;
    }

    // Appends `num` random samples to `samples`/`labels` (existing data is
    // kept). Each feature is drawn with dlib::rand::get_random_gaussian();
    // the label is -1 or +1 depending on the parity of a random 16-bit number.
    bool generateRandomSamples(int num)
    {
        dlib::rand Rd;
        for (int i = 0; i < num; i++)
        {
            sample_type samp;
            for (int j = 0; j < nFeatures; j++)
                samp(j) = Rd.get_random_gaussian();
            samples.push_back(samp);

            const double label = (Rd.get_random_16bit_number() % 2 == 0) ? -1.0 : 1.0;
            labels.push_back(label);
        }
        cout << "randomly generated " << num << " samples\n";
        return true;
    }

    // Trains the normalizer on `samples` (learning mean and variance),
    // replaces every sample by its normalized version, then shuffles
    // samples and labels together so repeated cross-validation sees a
    // random order.
    bool normalization()
    {
        normalizer.train(samples);
        for (unsigned long i = 0; i < samples.size(); ++i)
            samples[i] = normalizer(samples[i]);

        randomize_samples(samples, labels);
        return true;
    }

    // Grid search over the hyper-parameters using 10-fold cross-validation.
    //
    // opt == NUTrainer : searches gamma in [1e-5, 1] (x5 steps) and
    //                    nu in [1e-5, max_nu) (x5 steps); stores
    //                    best_gamma / best_nu.
    // opt == CTrainer  : searches gamma in [1e-5, 1] (x5 steps) and
    //                    C in [1, 2000) (x2 steps); stores
    //                    best_gamma / best_c.
    //
    // The score of a combination is sum(result) of the matrix returned by
    // cross_validate_trainer; the first strictly better combination wins.
    // Progress and the final choice are printed to stdout.
    bool findBestParam(int opt = CTrainer)
    {
        // Upper bound for nu, derived by dlib from the ratio of +1/-1 labels.
        const double max_nu = maximum_nu(labels);
        cout << "max_nu = " << max_nu << endl;
        cout << "doing cross validation..." << endl;

        matrix<double> best_result(1, 2);
        best_result = 0;
        best_gamma = 0.0001;
        best_nu = 0.0001;
        best_c = 5;

        switch (opt)
        {
        case NUTrainer:
            for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
            {
                // The kernel only depends on gamma, so set it once per outer step.
                trainer.set_kernel(kernel_type(gamma));
                for (double nu = 0.00001; nu < max_nu; nu *= 5)
                {
                    trainer.set_nu(nu);
                    cout << "gamma: " << gamma << "    nu: " << nu;
                    matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
                    cout << "     cross validation accuracy: " << result;
                    if (sum(result) > sum(best_result))
                    {
                        best_result = result;
                        best_gamma = gamma;
                        best_nu = nu;
                    }
                }
            }
            cout << "\nbest gamma: " << best_gamma << "      best nu: " << best_nu
                 << "      best score: " << best_result << "mean acc:  " << mean(best_result) << endl;
            break;

        case CTrainer:
            for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
            {
                c_trainer.set_kernel(kernel_type(gamma));
                for (double _c = 1; _c < 2000; _c *= 2)
                {
                    c_trainer.set_c(_c);
                    cout << "gamma: " << gamma << "    C: " << _c;
                    matrix<double> result = cross_validate_trainer(c_trainer, samples, labels, 10);
                    cout << "     cross validation accuracy: " << result;
                    if (sum(result) > sum(best_result))
                    {
                        best_result = result;
                        best_gamma = gamma;
                        best_c = _c;
                    }
                }
            }
            cout << "\nbest gamma: " << best_gamma << "      best c: " << best_c
                 << "      best score: " << best_result << "mean acc:  " << mean(best_result) << endl;
            break;
        }
        return true;
    }

    // Manual overrides for the hyper-parameters used by learnFunc().
    void setGamma(double _gamma) { best_gamma = _gamma; }
    void setNu(double _nu) { best_nu = _nu; }
    void setC(double _c) { best_c = _c; }

    // Appends one sample built from the first nFeatures values of `_pData`
    // together with its label. Returns false (and adds nothing) when
    // `_pData` is null.
    bool addSample(double* _pData, double _label)
    {
        if (!_pData)
            return false;

        sample_type samp;
        for (int i = 0; i < nFeatures; i++)
            samp(i) = _pData[i];
        samples.push_back(samp);
        labels.push_back(_label);
        return true;
    }

    // Clears the current samples/labels and refills them from a 2-D array.
    // Each row must hold nFeatures feature values followed by the label at
    // index nFeatures. Returns false (leaving the current data untouched)
    // when rows are requested but `_ppSamples` is null.
    bool setSamples(double** _ppSamples, int _nsamples)
    {
        if (_nsamples > 0 && !_ppSamples)
            return false;

        samples.clear();
        labels.clear();
        for (int i = 0; i < _nsamples; i++)
        {
            sample_type samp;
            for (int j = 0; j < nFeatures; j++)
                samp(j) = _ppSamples[i][j];
            samples.push_back(samp);
            labels.push_back(_ppSamples[i][nFeatures]);
        }
        return true;
    }

    // Trains the probabilistic classifier on the current samples/labels
    // with best_gamma and best_c (CTrainer) or best_nu (NUTrainer); the
    // probability calibration uses 3 folds. The member normalizer is copied
    // into the learned function, which is a normalized_function, so
    // predictProbability() is presumably meant to receive raw,
    // un-normalized input.
    bool learnFunc(int opt = CTrainer)
    {
        switch (opt)
        {
        case CTrainer:
            c_trainer.set_kernel(kernel_type(best_gamma));
            // Use the tuned C here; passing best_nu would silently discard
            // the result of findBestParam()/setC().
            c_trainer.set_c(best_c);
            learned_pfunct.normalizer = normalizer;
            learned_pfunct.function = train_probabilistic_decision_function(c_trainer, samples, labels, 3);
            break;

        case NUTrainer:
            trainer.set_kernel(kernel_type(best_gamma));
            trainer.set_nu(best_nu);
            learned_pfunct.normalizer = normalizer;
            learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
            break;
        }
        cout << "\nnumber of support vectors in our learned_pfunct is "
             << learned_pfunct.function.decision_funct.basis_vectors.size() << endl;
        return true;
    }

    // Returns the output of the learned probabilistic function for `_samp`.
    double predictProbability(const sample_type& _samp)
    {
        return learned_pfunct(_samp);
    }

    // Same as above, reading the sample from the first nFeatures values of
    // `_val` (which must not be null).
    double predictProbability(double* _val)
    {
        sample_type samp;
        for (int i = 0; i < nFeatures; i++)
            samp(i) = _val[i];
        return learned_pfunct(samp);
    }

    // Serializes the learned function to the file `fn`.
    // NOTE(review): dlib's serialize proxy is expected to throw on I/O
    // failure, in which case this never returns false -- confirm.
    bool saveLearnedFunc(const char* fn)
    {
        serialize(fn) << learned_pfunct;
        cout << "saved learned function to " << fn << endl;
        return true;
    }

    // Restores the learned function from the file `fn`.
    // NOTE(review): as with saveLearnedFunc, failures presumably surface as
    // dlib exceptions rather than a false return -- confirm.
    bool loadLearnedFunc(const char* fn)
    {
        deserialize(fn) >> learned_pfunct;
        cout << "loaded learned function from " << fn << endl;
        return true;
    }

    // Prints the 3-fold cross-validation result obtained when the selected
    // trainer is wrapped in reduced2(..., nVectors), i.e. limited to
    // `nVectors` support vectors. The trainer keeps whatever kernel and
    // C/nu it was last given by findBestParam() or learnFunc().
    bool getAccByCrossValidateWithVectors(int nVectors, int opt = CTrainer)
    {
        cout << "\ncross validation accuracy with only " << nVectors << " support vectors: ";
        switch (opt)
        {
        case CTrainer:
            cout << cross_validate_trainer(reduced2(c_trainer, nVectors), samples, labels, 3);
            break;
        case NUTrainer:
            cout << cross_validate_trainer(reduced2(trainer, nVectors), samples, labels, 3);
            break;
        }
        return true;
    }

private:
    std::vector<sample_type> samples;           // training samples
    std::vector<double> labels;                 // +1 / -1 label per training sample
    std::vector<sample_type> testData;          // unlabeled data filled by loadData(..., LoadTestData)
    svm_nu_trainer<kernel_type> trainer;        // used for NUTrainer
    svm_c_trainer<kernel_type> c_trainer;       // used for CTrainer
    vector_normalizer<sample_type> normalizer;  // trained in normalization()
    double best_gamma;
    double best_nu;
    double best_c;
    pfunct_type learned_pfunct;                 // result of learnFunc() / loadLearnedFunc()

protected:
};
}



0 0
原创粉丝点击