LibSVM C/C++
来源:互联网 发布:淘宝上哪家店牛仔裤好 编辑:程序博客网 时间:2024/06/10 18:10
本系列文章由 @YhL_Leo 出品,转载请注明出处。
文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779
在 LibSVM 库的 svm.h 头文件中定义了四个主要结构体:
1 训练模型的结构体
/*
 * Training set description: `l` samples, each with one label in `y` and one
 * feature vector in `x`.
 *
 * Each member is on its own line so that the trailing comments cannot
 * swallow the following declarations.
 */
struct svm_problem
{
    int l;                  /* total number of samples */
    double *y;              /* label of each sample */
    struct svm_node **x;    /* feature vector of each sample */
};
样本的类别通常使用 +1 与 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。
2 数据节点的结构体
/*
 * One (index, value) entry of a feature vector.
 */
struct svm_node
{
    int index;      /* feature index */
    double value;   /* feature value stored at that index */
};
数据组织结构如图1所示:
3 模型参数结构体
/*
 * Training / kernel configuration.
 * Member order is unchanged, so positional aggregate initialisation keeps
 * working.
 */
struct svm_parameter
{
    int svm_type;
    int kernel_type;
    int degree;           /* for poly */
    double gamma;         /* for poly/rbf/sigmoid */
    double coef0;         /* for poly/sigmoid */

    /* these are for training only */
    double cache_size;    /* in MB */
    double eps;           /* stopping criteria */
    double C;             /* for C_SVC, EPSILON_SVR and NU_SVR */
    int nr_weight;        /* for C_SVC */
    int *weight_label;    /* for C_SVC */
    double* weight;       /* for C_SVC */
    double nu;            /* for NU_SVC, ONE_CLASS, and NU_SVR */
    double p;             /* for EPSILON_SVR */
    int shrinking;        /* use the shrinking heuristics */
    int probability;      /* do probability estimates */
};
其中,各个参数的含义为:
-s svm_type : set type of SVM (default 0)
    0 -- C-SVC
    1 -- nu-SVC
    2 -- one-class SVM
    3 -- epsilon-SVR
    4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
    0 -- linear: u'*v
    1 -- polynomial: (gamma*u'*v + coef0)^degree
    2 -- radial basis function: exp(-gamma*|u-v|^2)
    3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)
SVM模型类型和核函数类型:
/* svm_type: numeric values match the `-s` option */
enum
{
    C_SVC       = 0,
    NU_SVC      = 1,
    ONE_CLASS   = 2,
    EPSILON_SVR = 3,
    NU_SVR      = 4
};

/* kernel_type: numeric values match the `-t` option */
enum
{
    LINEAR      = 0,
    POLY        = 1,
    RBF         = 2,
    SIGMOID     = 3,
    PRECOMPUTED = 4
};
4 训练输出模型结构体
/*
 * Model description: parameters, support vectors and decision-function
 * data.
 */
struct svm_model
{
    struct svm_parameter param;   /* parameter */
    int nr_class;                 /* number of classes, = 2 in regression/one class svm */
    int l;                        /* total #SV */
    struct svm_node **SV;         /* SVs (SV[l]) */
    double **sv_coef;             /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
    double *rho;                  /* constants in decision functions (rho[k*(k-1)/2]) */
    double *probA;                /* pairwise probability information */
    double *probB;
    int *sv_indices;              /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data]
                                     to indicate SVs in the training set */

    /* for classification only */

    int *label;                   /* label of each class (label[k]) */
    int *nSV;                     /* number of SVs for each class (nSV[k]) */
                                  /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
    /* XXX */
    int free_sv;                  /* 1 if svm_model is created by svm_load_model */
                                  /* 0 if svm_model is created by svm_train */
};
5 使用方法
以 LibSVM 提供的样本特征集 heart_scale 为例,首先需要读取样本特征数据,可以利用 svm-train.c 文件中的 read_problem 函数;为了方便使用,对其进行了改写:
// TrainingDataLoad.h/* Load training data from svm format file. - Editor: Yahui Liu. - Data: 2015-11-30 - Email: yahui.cvrs@gmail.com - Address: Computer Vision and Remote Sensing(CVRS), Lab.**/#ifndef TRAINING_DATA_LOAD_H#define TRAINING_DATA_LOAD_H#pragma once#include <stdio.h>#include <stdlib.h>#include <ctype.h>#include <iostream>#include <vector>#include <string>#include <fstream>#include <errno.h>#include "svm.h"//#include "svm-scale.c"using namespace std;#define MAX_LINE_LEN 1024class TrainingDateLoad{public: TrainingDateLoad() { line = NULL; } ~TrainingDateLoad() { line = NULL; }public: char* line;// public:// static struct svm_parameter _paramInit;public: /*! load svm model */ void loadModel( std::string filename, struct svm_model*& model); /*! skip the target */ void svmSkipTarget( char*& p); /* skip the element */ void svmSkipElement( char*& p); void initialParams( struct svm_parameter& param ); /*! load training data */ void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param ); char* readline(FILE *input); void exit_input_error(int line_num) { cout << "Wrong input format at line: " << line_num << endl; exit(1); }};#endif // TRAINING_DATA_LOAD_H
// TrainingDataLoad.cpp#include "TrainingDataLoad.h"void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model){ model = svm_load_model(filename.c_str());}void TrainingDateLoad::svmSkipTarget(char*& p){ while(isspace(*p)) ++p; while(!isspace(*p)) ++p;}void TrainingDateLoad::svmSkipElement(char*& p){ while(*p!=':') ++p; ++p; while(isspace(*p)) ++p; while(*p && !isspace(*p)) ++p;}void TrainingDateLoad::initialParams( struct svm_parameter& param ){ // default values param.svm_type = C_SVC; param.kernel_type = RBF; param.degree = 3; param.gamma = 0; // 1/num_features param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL;}void TrainingDateLoad::readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param ){ int max_index, inst_max_index, i; size_t elements, j; FILE *fp = fopen(filename.c_str(),"r"); char *endptr; char *idx, *val, *label; if(fp == NULL) { fprintf(stderr,"can't open input file %s\n",filename); exit(1); } prob.l = 0; elements = 0; line = new char[MAX_LINE_LEN]; while(readline(fp)!=NULL) { char *p = strtok(line," \t"); // label // features while(1) { p = strtok(NULL," \t"); if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature break; ++elements; } ++elements; ++prob.l; } rewind(fp); prob.y = new double[prob.l]; prob.x = new struct svm_node *[prob.l]; struct svm_node *x_space = new struct svm_node[elements]; max_index = 0; j=0; for(i=0;i<prob.l;i++) { inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0 readline(fp); prob.x[i] = &x_space[j]; label = strtok(line," \t\n"); if(label == NULL) // empty line exit_input_error(i+1); prob.y[i] = strtod(label,&endptr); if(endptr == label || *endptr != '\0') exit_input_error(i+1); while(1) { idx = strtok(NULL,":"); val = strtok(NULL," 
\t"); if(val == NULL) break; errno = 0; x_space[j].index = (int) strtol(idx,&endptr,10); if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index) exit_input_error(i+1); else inst_max_index = x_space[j].index; errno = 0; x_space[j].value = strtod(val,&endptr); if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr))) exit_input_error(i+1); ++j; } if(inst_max_index > max_index) max_index = inst_max_index; x_space[j++].index = -1; } if(param.gamma == 0 && max_index > 0) param.gamma = 1.0/max_index; if(param.kernel_type == PRECOMPUTED) for(i=0;i<prob.l;i++) { if (prob.x[i][0].index != 0) { fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n"); exit(1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { fprintf(stderr,"Wrong input format: sample_serial_number out of range\n"); exit(1); } } fclose(fp);}char* TrainingDateLoad::readline(FILE *input){ int len; if(fgets(line,MAX_LINE_LEN,input) == NULL) return NULL; int max_line_len = MAX_LINE_LEN; while(strrchr(line,'\n') == NULL) { max_line_len *= 2; line = (char *) realloc(line,max_line_len); len = (int) strlen(line); if(fgets(line+len,max_line_len-len,input) == NULL) break; } return line;}
将样本训练与预测进行改写:
// LibSVMTools.h/* LibSVM train and predict tools. - Editor: Yahui Liu. - Data: 2015-12-3 - Email: yahui.cvrs@gmail.com - Address: Computer Vision and Remote Sensing(CVRS), Lab.**/#ifndef LIBSVM_TOOL_H#define LIBSVM_TOOL_H#pragma once#include <iostream>#include <string>#include "svm.h"#include "TrainingDataLoad.h"class LibSVMTools{public: LibSVMTools(){} ~LibSVMTools(){}public: /*! - featureFile: features of images saved in libsvm format. - saveModelFile: save the trained model file. **/ void libSvmTrain(std::string featureFile, std::string saveModelFile); /*! - featureFile: features of images saved in libsvm format. - modelFile: libsvm trained model. - savePredictFile: save the predicting results. **/ void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);};#endif // LIBSVM_TOOL_H
// LibSVMTools.cpp#include "LibSVMTools.h"void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile){ struct svm_parameter param; struct svm_problem prob; TrainingDateLoad* trainData = new TrainingDateLoad; trainData->initialParams( param ); trainData->readProblem(featureFile, prob, param); const char*errorMsg = svm_check_parameter(&prob, ¶m); if ( errorMsg ) { cout << errorMsg << endl; return; } struct svm_model *model = svm_train(&prob, ¶m);#if 1 cout << "svm_type: " << model->param.svm_type << endl << "kernel_type: " << model->param.kernel_type << endl << "gamma: " << model->param.gamma << endl << "nr_class: " << model->nr_class << endl << "total_sv: " << model->l << endl << "rho: " << model->rho[0] << endl << "label: " << model->label[0] << " " << model->label[1] << endl << "nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;#endif int saveModel = svm_save_model( saveModelFile.c_str(), model );}void LibSVMTools::libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile){ struct svm_parameter param; struct svm_problem prob; TrainingDateLoad * trainData = new TrainingDateLoad; trainData->initialParams( param ); trainData->readProblem(featureFile, prob, param); struct svm_model* model; trainData->loadModel(modelFile.c_str(), model); float correct(0.0); // all correct float uncorrect_1(0.0); // pos to neg float uncorrect_2(0.0); // neg to pos if ( prob.l ) { const int nCount = prob.l;; ofstream outfile( savePredictFile, ios::out ); for( int i=0; i<nCount; i++ ) { double label = svm_predict(model, prob.x[i]); if ( label == prob.y[i] ) { correct ++; } else if ( label == -1.0 ) { uncorrect_1 ++; } else { uncorrect_2 ++; } outfile << label << endl; }#if 1 cout << "total data count: " << nCount << endl << "classification correct: " << correct << endl << "pos to neg count: " << uncorrect_1 << endl << "neg to pos count: " << uncorrect_2 << endl; cout << "Accuracy: " << static_cast<float>(correct/nCount) << "(" 
<< correct << "/" << nCount << ")" << endl;#endif outfile.close(); }}
用例Demo:
// train#include "LibSVMTools.h"void main(){ std::cout << "************************************************************" << endl << "** PROGRAM: LibSVM model training. **" << endl << "** **" << endl << "** Author: Yahui Liu. **" << endl << "** School of Remote Sensing & Inf. Eng. **" << endl << "** Wuhan University, Hubei, P.R. China **" << endl << "** Email: yahui.cvrs@gmail.com **" << endl << "** Create time: Dec. 1, 2015 **" << endl << "************************************************************" << endl; string filename = "..\\..\\..\\Data\\heat_scale"; std::string savefielname = "..\\..\\..\\Data\\train.model"; LibSVMTools* libsvm = new LibSVMTools(); libsvm->libSvmTrain(filename, savefielname); delete libsvm;}/*------------------------------------------------------------------------------------*/// predict#include "LibSVMTools.h"void main(){ std::cout << "************************************************************" << endl << "** PROGRAM: LibSVM predict. **" << endl << "** **" << endl << "** Author: Yahui Liu. **" << endl << "** School of Remote Sensing & Inf. Eng. **" << endl << "** Wuhan University, Hubei, P.R. China **" << endl << "** Email: yahui.cvrs@gmail.com **" << endl << "** Create time: Dec. 1, 2015 **" << endl << "************************************************************" << endl; std::string featureFile = "..\\..\\..\\Data\\heart_scale"; std::string modelFile = "..\\..\\..\\Data\\train.model"; std::string savePredictFile = "..\\..\\..\\Data\\predict.out"; LibSVMTools* libsvm = new LibSVMTools(); libsvm->libSvmPredict(featureFile, modelFile, savePredictFile); delete libsvm;}
0 0
- LibSVM C/C++
- libsvm移植到c/c++中
- libsvm-svm-scale.c 源码分析
- libsvm数据格式、c语言输出符合libsvm要求格式的特征文件代码
- LibSVM笔记系列(3)——初学移植libsvm的C/C++版本
- libsvm 线性核 C-SVM 参数寻优
- [matlab-libsvm] 关于SVM参数c&g选取程序
- LibSVM 3.12的源码分析Svm-train.c
- libsvm 使用python交叉验证 取最优参数 c g
- libsvm中参数c与g的调整
- CSV转LibSVM格式之C语言实现改进
- LIBSVM学习(六)代码结构及c-SVC过程
- 在libsvm中如何求最佳参数c和gamma
- [matlab-libsvm] 关于SVM参数c&g选取程序
- 关于SVM参数c&g选取的总结帖[matlab-libsvm]
- libsvm ——SVM中参数 c和g的最佳值的选择
- 关于libsvm的Java和C版本的运算结果不一致的问题
- 关于SVM参数c&g选取的总结帖[matlab-libsvm]
- 星空幻想
- 解压版的Tomcat基本配置和安装
- 《电子或通信领域当前的主流技术及其社会需求调查报告》
- hdu1176
- python基础教程共60课-第3课IDE
- LibSVM C/C++
- 53,类方法
- project euler 37
- E-栈--括号匹配
- POJ 2352 Stars (区间建树,单点更新)
- project euler 38
- 折腾了一个晚上,终于发表了第一篇学习技术博客,
- 终端卸载Ubuntu软件
- light--oj--1116--Ekka Dokka(数学问题)