LibSVM C/C++

来源:互联网 发布:淘宝上哪家店牛仔裤好 编辑:程序博客网 时间:2024/06/10 18:10


本系列文章由 @YhL_Leo 出品,转载请注明出处。
文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779


LibSVM库的svm.h头文件中定义了四个主要结构体:

1 训练模型的结构体

/* Training set handed to LibSVM: `l` samples, each with a label and a
   feature vector. */
struct svm_problem
{
    int l;                // total number of samples
    double *y;            // label of each sample
    struct svm_node **x;  // feature vector of each sample
};

样本的类别通常使用 +1 和 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。

2 数据节点的结构体

/* One (index, value) entry of a sample's feature vector. The readProblem
   routine shown later in this article fills these from "index:value" pairs
   and ends every sample with a node whose index is -1. */
struct svm_node
{
    int index;     /* feature index */
    double value;  /* feature value */
};

数据组织结构如图1所示:

3 模型参数结构体

/* Configuration of the SVM type, the kernel and the training procedure. */
struct svm_parameter
{
    int svm_type;
    int kernel_type;
    int degree;         /* for poly */
    double gamma;       /* for poly/rbf/sigmoid */
    double coef0;       /* for poly/sigmoid */

    /* these are for training only */
    double cache_size;  /* in MB */
    double eps;         /* stopping criteria */
    double C;           /* for C_SVC, EPSILON_SVR and NU_SVR */
    int nr_weight;      /* for C_SVC */
    int *weight_label;  /* for C_SVC */
    double* weight;     /* for C_SVC */
    double nu;          /* for NU_SVC, ONE_CLASS, and NU_SVR */
    double p;           /* for EPSILON_SVR */
    int shrinking;      /* use the shrinking heuristics */
    int probability;    /* do probability estimates */
};

其中,各个参数的含义为:

-s svm_type : set type of SVM (default 0)
    0 -- C-SVC
    1 -- nu-SVC
    2 -- one-class SVM
    3 -- epsilon-SVR
    4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
    0 -- linear: u'*v
    1 -- polynomial: (gamma*u'*v + coef0)^degree
    2 -- radial basis function: exp(-gamma*|u-v|^2)
    3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)

SVM模型类型和核函数类型:

/* svm_type: values accepted by svm_parameter::svm_type */
enum
{
    C_SVC,
    NU_SVC,
    ONE_CLASS,
    EPSILON_SVR,
    NU_SVR
};

/* kernel_type: values accepted by svm_parameter::kernel_type */
enum
{
    LINEAR,
    POLY,
    RBF,
    SIGMOID,
    PRECOMPUTED
};

4 训练输出模型结构体

/* Result of training (or of loading a saved model). */
struct svm_model
{
    struct svm_parameter param;  /* parameter */
    int nr_class;                /* number of classes, = 2 in regression/one class svm */
    int l;                       /* total #SV */
    struct svm_node **SV;        /* SVs (SV[l]) */
    double **sv_coef;            /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
    double *rho;                 /* constants in decision functions (rho[k*(k-1)/2]) */
    double *probA;               /* pairwise probability information */
    double *probB;
    int *sv_indices;             /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data]
                                    to indicate SVs in the training set */

    /* for classification only */

    int *label;                  /* label of each class (label[k]) */
    int *nSV;                    /* number of SVs for each class (nSV[k]) */
                                 /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */

    /* XXX */
    int free_sv;                 /* 1 if svm_model is created by svm_load_model */
                                 /* 0 if svm_model is created by svm_train */
};

5 使用方法

以LibSVM提供的样本特征集heart_scale为例,首先需要读取样本特征数据,可以利用svm-train.c文件中的read_problem函数,为了方便使用,对其进行了改写:

// TrainingDataLoad.h/*    Load training data from svm format file.    - Editor: Yahui Liu.    - Data:   2015-11-30    - Email:  yahui.cvrs@gmail.com    - Address: Computer Vision and Remote Sensing(CVRS), Lab.**/#ifndef TRAINING_DATA_LOAD_H#define TRAINING_DATA_LOAD_H#pragma once#include <stdio.h>#include <stdlib.h>#include <ctype.h>#include <iostream>#include <vector>#include <string>#include <fstream>#include <errno.h>#include "svm.h"//#include "svm-scale.c"using namespace std;#define MAX_LINE_LEN 1024class TrainingDateLoad{public:    TrainingDateLoad()    {        line = NULL;    }    ~TrainingDateLoad()    {        line = NULL;    }public:    char* line;// public://  static struct svm_parameter _paramInit;public:    /*! load svm model */    void loadModel( std::string filename,  struct svm_model*& model);    /*! skip the target */    void svmSkipTarget( char*& p);    /* skip the element */    void svmSkipElement( char*& p);    void initialParams( struct svm_parameter& param );      /*! load training data */    void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );    char* readline(FILE *input);     void exit_input_error(int line_num)    {        cout << "Wrong input format at line: " << line_num << endl;        exit(1);    }};#endif // TRAINING_DATA_LOAD_H
// TrainingDataLoad.cpp#include "TrainingDataLoad.h"void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model){    model = svm_load_model(filename.c_str());}void TrainingDateLoad::svmSkipTarget(char*& p){    while(isspace(*p)) ++p;    while(!isspace(*p)) ++p;}void TrainingDateLoad::svmSkipElement(char*& p){    while(*p!=':') ++p;    ++p;    while(isspace(*p)) ++p;    while(*p && !isspace(*p)) ++p;}void TrainingDateLoad::initialParams( struct svm_parameter& param ){    // default values    param.svm_type = C_SVC;    param.kernel_type = RBF;    param.degree = 3;    param.gamma = 0;    // 1/num_features    param.coef0 = 0;    param.nu = 0.5;    param.cache_size = 100;    param.C = 1;    param.eps = 1e-3;    param.p = 0.1;    param.shrinking = 1;    param.probability = 0;    param.nr_weight = 0;    param.weight_label = NULL;    param.weight = NULL;}void TrainingDateLoad::readProblem( std::string filename,     struct svm_problem& prob, struct svm_parameter& param ){    int max_index, inst_max_index, i;    size_t elements, j;    FILE *fp = fopen(filename.c_str(),"r");    char *endptr;    char *idx, *val, *label;    if(fp == NULL)    {        fprintf(stderr,"can't open input file %s\n",filename);        exit(1);    }    prob.l = 0;    elements = 0;    line = new char[MAX_LINE_LEN];    while(readline(fp)!=NULL)    {        char *p = strtok(line," \t"); // label        // features        while(1)        {            p = strtok(NULL," \t");            if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature                break;            ++elements;        }        ++elements;        ++prob.l;    }    rewind(fp);    prob.y = new double[prob.l];    prob.x = new struct svm_node *[prob.l];    struct svm_node *x_space = new struct svm_node[elements];    max_index = 0;    j=0;    for(i=0;i<prob.l;i++)    {        inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0        
readline(fp);        prob.x[i] = &x_space[j];        label = strtok(line," \t\n");        if(label == NULL) // empty line            exit_input_error(i+1);        prob.y[i] = strtod(label,&endptr);        if(endptr == label || *endptr != '\0')            exit_input_error(i+1);        while(1)        {            idx = strtok(NULL,":");            val = strtok(NULL," \t");            if(val == NULL)                break;            errno = 0;            x_space[j].index = (int) strtol(idx,&endptr,10);            if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)                exit_input_error(i+1);            else                inst_max_index = x_space[j].index;            errno = 0;            x_space[j].value = strtod(val,&endptr);            if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))                exit_input_error(i+1);            ++j;        }        if(inst_max_index > max_index)            max_index = inst_max_index;        x_space[j++].index = -1;    }    if(param.gamma == 0 && max_index > 0)        param.gamma = 1.0/max_index;    if(param.kernel_type == PRECOMPUTED)        for(i=0;i<prob.l;i++)        {            if (prob.x[i][0].index != 0)            {                fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");                exit(1);            }            if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)            {                fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");                exit(1);            }        }        fclose(fp);}char* TrainingDateLoad::readline(FILE *input){    int len;    if(fgets(line,MAX_LINE_LEN,input) == NULL)        return NULL;    int max_line_len = MAX_LINE_LEN;    while(strrchr(line,'\n') == NULL)    {        max_line_len *= 2;        line = (char *) realloc(line,max_line_len);        len = (int) strlen(line);        
if(fgets(line+len,max_line_len-len,input) == NULL)            break;    }    return line;}

将样本训练与预测进行改写:

// LibSVMTools.h/*    LibSVM train and predict tools.    - Editor: Yahui Liu.    - Data:   2015-12-3    - Email:  yahui.cvrs@gmail.com    - Address: Computer Vision and Remote Sensing(CVRS), Lab.**/#ifndef LIBSVM_TOOL_H#define LIBSVM_TOOL_H#pragma once#include <iostream>#include <string>#include "svm.h"#include "TrainingDataLoad.h"class LibSVMTools{public:    LibSVMTools(){}    ~LibSVMTools(){}public:    /*!        - featureFile: features of images saved in libsvm format.        - saveModelFile: save the trained model file.    **/    void libSvmTrain(std::string featureFile, std::string saveModelFile);    /*!        - featureFile: features of images saved in libsvm format.        - modelFile: libsvm trained model.        - savePredictFile: save the predicting results.    **/    void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);};#endif // LIBSVM_TOOL_H
// LibSVMTools.cpp#include "LibSVMTools.h"void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile){    struct svm_parameter param;    struct svm_problem prob;    TrainingDateLoad* trainData = new TrainingDateLoad;    trainData->initialParams( param );    trainData->readProblem(featureFile, prob, param);    const char*errorMsg = svm_check_parameter(&prob, &param);    if ( errorMsg )    {        cout << errorMsg << endl;        return;    }    struct svm_model *model = svm_train(&prob, &param);#if 1    cout << "svm_type: " << model->param.svm_type << endl <<        "kernel_type: " << model->param.kernel_type << endl <<        "gamma: " << model->param.gamma << endl <<        "nr_class: " << model->nr_class << endl <<        "total_sv: " << model->l << endl <<        "rho: " << model->rho[0] << endl <<        "label: " << model->label[0] << " " << model->label[1] << endl <<        "nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;#endif    int saveModel = svm_save_model( saveModelFile.c_str(), model );}void LibSVMTools::libSvmPredict(std::string featureFile,     std::string modelFile, std::string savePredictFile){    struct svm_parameter param;    struct svm_problem prob;    TrainingDateLoad * trainData = new TrainingDateLoad;    trainData->initialParams( param );    trainData->readProblem(featureFile, prob, param);    struct svm_model* model;    trainData->loadModel(modelFile.c_str(), model);    float correct(0.0);     // all correct    float uncorrect_1(0.0); // pos to neg    float uncorrect_2(0.0); // neg to pos    if ( prob.l )    {        const int nCount = prob.l;;        ofstream outfile( savePredictFile, ios::out );        for( int i=0; i<nCount; i++ )        {            double label = svm_predict(model, prob.x[i]);            if ( label == prob.y[i] )            {                correct ++;            }            else if ( label == -1.0 )            {                uncorrect_1 ++;            }            else            {  
              uncorrect_2 ++;            }            outfile << label << endl;        }#if 1        cout << "total data count: " << nCount << endl <<            "classification correct: " << correct << endl <<            "pos to neg count: " << uncorrect_1 << endl <<            "neg to pos count: " << uncorrect_2 << endl;        cout << "Accuracy: " << static_cast<float>(correct/nCount)             << "(" << correct << "/" << nCount << ")" << endl;#endif        outfile.close();    }}

用例Demo:

// train#include "LibSVMTools.h"void main(){    std::cout <<         "************************************************************" << endl <<        "**          PROGRAM: LibSVM model training.               **" << endl <<         "**                                                        **" << endl <<        "**           Author: Yahui Liu.                           **" << endl <<        "**                   School of Remote Sensing & Inf. Eng. **" << endl <<         "**                   Wuhan University, Hubei, P.R. China  **" << endl <<        "**            Email: yahui.cvrs@gmail.com                 **" << endl <<        "**      Create time: Dec. 1, 2015                         **" << endl <<        "************************************************************" << endl;    string filename = "..\\..\\..\\Data\\heat_scale";    std::string savefielname = "..\\..\\..\\Data\\train.model";    LibSVMTools* libsvm = new LibSVMTools();    libsvm->libSvmTrain(filename, savefielname);    delete libsvm;}/*------------------------------------------------------------------------------------*/// predict#include "LibSVMTools.h"void main(){    std::cout <<         "************************************************************" << endl <<        "**          PROGRAM: LibSVM predict.                      **" << endl <<         "**                                                        **" << endl <<        "**           Author: Yahui Liu.                           **" << endl <<        "**                   School of Remote Sensing & Inf. Eng. **" << endl <<         "**                   Wuhan University, Hubei, P.R. China  **" << endl <<        "**            Email: yahui.cvrs@gmail.com                 **" << endl <<        "**      Create time: Dec. 
1, 2015                         **" << endl <<        "************************************************************" << endl;    std::string featureFile = "..\\..\\..\\Data\\heart_scale";    std::string modelFile = "..\\..\\..\\Data\\train.model";    std::string savePredictFile = "..\\..\\..\\Data\\predict.out";    LibSVMTools* libsvm = new LibSVMTools();    libsvm->libSvmPredict(featureFile, modelFile, savePredictFile);    delete libsvm;}
0 0
原创粉丝点击