SVM 算法

来源:互联网 发布:类似blued的软件 编辑:程序博客网 时间:2024/05/18 09:01

关于 SVM 算法的详解可以参考相关链接,下面的 SVM 算法基于 OpenCV 的 CvSVM 实现:


#ifndef __SVM_TRAIN__#define __SVM_TRAIN__#include<opencv2\opencv.hpp>#include <opencv2/core/core.hpp>//#include <opencv2/highgui/highgui.hpp>#include <opencv2/ml/ml.hpp>//#include"ml.h"typedef struct { CvMat *data_mat; CvMat *class_mat;  //数据的类别}Vector_data,*vector_data;class MySvm: public CvSVM  {  public:      int get_alpha_count()      {          return this->sv_total;      }       int get_sv_dim()      {          return this->var_all;      }       int get_sv_count()      {         return this->decision_func->sv_count;     }      double* get_alpha()     {         return this->decision_func->alpha;      }        float** get_sv()      {          return this->sv;      }        float get_rho()      {          return this->decision_func->rho;      }  }; class svm_Train{public:svm_Train(int sample_count,int sample_size);~svm_Train();void svm_SetData(float** data, float* label);void svm_StartTrain();void svm_Save();float svm_Perdict(CvMat* TestMat );vector_data vector_element;MySvm* tsvm;CvSVMParams svm_param;int sample_count;int sample_size;bool isTrain,isSetData,isSave;};#endif

#include"svm_Train.h"svm_Train::svm_Train(int sample_count,int sample_size){this->vector_element=(vector_data)malloc(sizeof(Vector_data));this->vector_element->data_mat=cvCreateMat(sample_count,sample_size,CV_32FC1);this->vector_element->class_mat=cvCreateMat(sample_count,1,CV_32FC1);/*this->vector_element->data_mat=cvCreateMat(sample_size,sample_count,CV_32FC1);this->vector_element->class_mat=cvCreateMat(1,sample_count,CV_32FC1);*/this->tsvm=new MySvm();this->sample_count=sample_count;this->sample_size=sample_size;////this->svm_param.svm_type = CvSVM::C_SVC; ///*this->svm_param.svm_type = CvSVM::C_SVC;*/ ////this->svm_param.kernel_type = CvSVM::LINEAR;///*this->svm_param.kernel_type = CvSVM::RBF;*///this->svm_param.kernel_type =CvSVM::POLY;////this->svm_param.kernel_type =CvSVM::SIGMOID;//this->svm_param.gamma = 1./this->sample_size;    //this->svm_param.nu = 0.5;    //this->svm_param.C = 10;    //this->svm_param.term_crit.epsilon = 0.0001;    //this->svm_param.term_crit.max_iter = 1000;    //this->svm_param.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS; //this->svm_param.coef0=4;//this->svm_param.degree=4;    this->svm_param.svm_type = CvSVM::EPS_SVR;this->svm_param.kernel_type = CvSVM::RBF;    /*this->svm_param.kernel_type = CvSVM::POLY;  */    this->svm_param.C =300;      this->svm_param.p = 1e-3;      this->svm_param.gamma =0.29;//1./this->sample_size; /*this->svm_param.coef0=3;this->svm_param.degree=3.8;*//*this->svm_param.svm_type = CvSVM::NU_SVR;      this->svm_param.kernel_type = CvSVM::LINEAR;      this->svm_param.C =1;      this->svm_param.p = 1e-3;      this->svm_param.gamma = 5;this->svm_param.nu= 0.9;*/isTrain=false;isSetData=false;isSave=false;}void svm_Train::svm_SetData( float** data, float* label){/*cvInitMatHeader(vector_element->class_mat,sample_count,1,CV_32FC1,label);      cvInitMatHeader(vector_element->data_mat,sample_count,sample_size,CV_32FC1,data);  */    //cvInitMatHeader(&mtestd,1,2,CV_32FC1,testd);for(int i=0;i<sample_count;i++){ 
vector_element->class_mat->data.fl[i]=label[i]; for(int j=0;j<sample_size;j++){     vector_element->data_mat->data.fl[i*sample_size+j]=data[i][j];}}/*for(int i=0;i<sample_count;i++){ vector_element->class_mat->data.fl[i]=label[i]; for(int j=0;j<sample_size;j++){     vector_element->data_mat->data.fl[j*sample_count+i]=data[i][j];}}*/isSetData=true; }void svm_Train::svm_StartTrain(){if(isSetData){tsvm->train(vector_element->data_mat, vector_element->class_mat,0,0,svm_param); isTrain=true;}elseprintf("Please set train data first\n");}void svm_Train::svm_Save(){if(isTrain)tsvm->save("detector.xml", 0); //For predictionelseprintf("Please training data first\n");isSave=true;}float svm_Train::svm_Perdict(CvMat* TestMat){float value;if(isTrain){value=tsvm->predict(TestMat);}elseprintf("Please training data first\n");return value;}svm_Train::~svm_Train(){isTrain=false;isSetData=false;isSave=false;if(vector_element) free(vector_element);if(tsvm) free(tsvm);}


0 0