SMO的C++实现
来源:互联网 发布:29岁了才学淘宝美工 编辑:程序博客网 时间:2024/06/16 11:29
看SMO的论文已经有些时间了,一直想把它实现了,期间搜集了很多资料,可以跟大家分享。关于svm和smo,我就不想写东西了,因为想看这篇博客的肯定都了解了。
我把很多资源都放在下面的参考文献中,有些论文上传到CSDN了,0积分,大家想要的可以自己下载啊。
我这里主要是贴代码,由于我接触SVM的时候用的是opencv的SVM,所以那个简单易用的SVM接口给我留下了很深刻的印象,所以我觉得实现的时候就应该“山寨”这些简单易用的接口,让别人调用的时候很方便。所以对其进行了深度的封装。代码的第一部分是smo.h,主要是定义SMO这个类。第二部分是smo.cpp,是该类的实现。第三部分代码是main函数,简单展示如何使用这个类。代码有几个函数不是原创的,因为这些东西原始论文中是没有的,我是在其他论文或者代码上看到的,但是大部分的函数都是原创的。代码注释中提到的“文章中”或者“公式10”等表述,都是指参考文献1中的那篇最经典的论文。用来训练的数据是libsvm的经典heart_scale,点击这里获取。
SMO.h
#ifndef SMO_H
#define SMO_H

#include "stdafx.h"
#include <cstddef>
#include <vector>

// Kept for backward compatibility: translation units that include this header
// may rely on the std names it has always exported.
using namespace std;

#define MAX(a,b) ((a)>(b)?(a):(b))
#define MIN(a,b) ((a)<(b)?(a):(b))

// Training parameters handed to SMO::train().
struct SMOParams
{
    int m_nAllSample;              // total number of samples (training + test)
    int m_nTrainNumber;            // number of samples used for training
    int m_nDimension;              // number of features per sample
    double m_dC;                   // penalty parameter C
    double m_dT;                   // tolerance used in the KKT check
    double m_dEps;                 // epsilon used when comparing multipliers/objectives
    double m_dTwo_sigma_squared;   // 2*sigma^2 of the RBF kernel
};

// SVM classifier trained with the SMO algorithm (RBF kernel).
//
// The object owns the sample matrix m_dAllData through raw pointers, so it is
// deliberately non-copyable: a member-wise copy would free the matrix twice.
class SMO
{
public:
    // Starts from a well-defined empty state. Previously the members were left
    // uninitialised, so train() read an indeterminate m_bIsLoad and the
    // destructor could delete through garbage pointers on an unused object.
    SMO()
        : m_nAllSample(0),
          m_nTrainNumber(0),
          m_nDimension(0),
          m_dC(0.0),
          m_dT(0.0),
          m_dEps(0.0),
          m_dTwo_sigma_squared(0.0),
          m_dB(0.0),
          m_bIsLoad(false),
          m_dAllData(NULL)
    {}
    ~SMO();

    void train(const char *inputDataPath, const SMOParams &s); // train the classifier
    void save();                              // write the trained classifier to disk
    void load(const char* filePath);          // load a previously saved classifier
    void error_rate();                        // report accuracy on the held-out samples
    int predict(double *array, int length);   // classify one sample, returns 1 or -1

private:
    // Non-copyable (declared, intentionally never defined).
    SMO(const SMO &);
    SMO &operator=(const SMO &);

    bool takeStep(int i1, int i2);            // jointly optimise two Lagrange multipliers
    double ui(int i1);                        // classifier output, equation 10
    double kernelRBF(int, int);               // RBF kernel between two stored samples
    double kernelRBF(int, double*);           // RBF kernel overload used by predict()
    double dotProduct(int i1, int i2);        // dot product of two stored samples
    int examineExample(int);
    void readFile(const char *);
    int examineFirstChoice(int i1, double E1);
    int examineNonBound(int);
    int examineBound(int);
    void init(const SMOParams &s);
    void outerLoop();                         // outer loop of the paper's pseudo-code

private:
    int m_nAllSample;              // total number of samples
    int m_nTrainNumber;            // number of training samples
    int m_nDimension;              // number of features per sample
    double m_dC;                   // penalty parameter C
    double m_dT;                   // tolerance used in the KKT check
    double m_dEps;                 // epsilon used in comparisons
    double m_dTwo_sigma_squared;   // 2*sigma^2 of the RBF kernel
    double m_dB;                   // threshold b
    bool m_bIsLoad;                // true when the classifier was loaded from a file;
                                   // train() refuses to run in that case
    vector <int> m_vTarget;        // class labels
    double **m_dAllData;           // training and test samples, one row per sample
    vector<double> m_vAlph;        // Lagrange multipliers
    vector<double> m_vErrorCache;  // cached errors of the non-bound samples
    vector<double> m_vDotProductCache; // cached self dot products x_i . x_i
};

#endif // SMO_H
smo.cpp
#include "stdafx.h"
#include "smo.h"
#include <fstream>
#include <sstream>
#include <string>
#include <cmath>
#include <cstdlib>
#include <iostream>
using namespace std;

// Classifier output for stored sample k (equation 10):
// sum over training samples with alpha > 0 of alpha_i * y_i * K(i, k), minus b.
double SMO::ui(int k)
{
    double s = 0;
    for (int i = 0; i < m_nTrainNumber; i++)
        if (m_vAlph[i] > 0)
            s += m_vAlph[i] * m_vTarget[i] * kernelRBF(i, k);
    s -= m_dB;
    return s;
}

/*
 * Classifies one external sample.
 *   array  : feature values of the sample
 *   length : number of values in array; must equal m_nDimension
 * Returns 1 or -1.
 * Throws a heap-allocated exception (by pointer) when length does not match.
 * NOTE(review): exception(const char*) is not a standard constructor and the
 * object is thrown by pointer; both are kept so existing catch sites keep
 * working - confirm before modernising.
 */
int SMO::predict(double *array, int length)
{
    if (length != m_nDimension)
        throw new exception("Invalid input data !");
    double s = 0;
    for (int i = 0; i < m_nTrainNumber; i++)
        if (m_vAlph[i] > 0)
            s += m_vAlph[i] * m_vTarget[i] * kernelRBF(i, array);
    s -= m_dB;
    return s > 0 ? 1 : -1;
}

// Dot product of stored samples i1 and i2.
double SMO::dotProduct(int i1, int i2)
{
    double dot = 0;
    for (int i = 0; i < m_nDimension; i++)
        dot += m_dAllData[i1][i] * m_dAllData[i2][i];
    return dot;
}

// Gaussian (RBF) kernel between two stored samples. The squared distance is
// built from the cached self dot products: |x1|^2 + |x2|^2 - 2 * x1.x2.
double SMO::kernelRBF(int i1, int i2)
{
    double s = dotProduct(i1, i2);
    s *= -2;
    s += m_vDotProductCache[i1] + m_vDotProductCache[i2];
    return exp(-s / m_dTwo_sigma_squared);
}

// RBF kernel between stored sample i1 and an external sample; used by
// predict(). Computes all three dot products itself instead of using the cache.
double SMO::kernelRBF(int i1, double *inputData)
{
    double cross = 0;      // x_i1 . input
    double inputSelf = 0;  // input . input
    double storedSelf = 0; // x_i1 . x_i1
    for (int i = 0; i < m_nDimension; ++i)
    {
        storedSelf += m_dAllData[i1][i] * m_dAllData[i1][i];
        inputSelf += inputData[i] * inputData[i];
        cross += m_dAllData[i1][i] * inputData[i];
    }
    const double squaredDistance = storedSelf + inputSelf - 2 * cross;
    return exp(-squaredDistance / m_dTwo_sigma_squared);
}

// Jointly optimises the multipliers of samples i1 and i2, following the
// takeStep pseudo-code of the paper. Returns true when the pair was changed.
bool SMO::takeStep(int i1, int i2)
{
    if (i1 == i2) return false;

    const double alph1 = m_vAlph[i1];
    const double alph2 = m_vAlph[i2];
    const double y1 = m_vTarget[i1];
    const double y2 = m_vTarget[i2];

    // The error cache is only maintained for non-bound multipliers, so its
    // validity is decided by alpha (not by the cached value itself).
    const double E1 = (alph1 > 0 && alph1 < m_dC) ? m_vErrorCache[i1] : ui(i1) - y1;
    // E2 was previously never computed and stayed 0.
    const double E2 = (alph2 > 0 && alph2 < m_dC) ? m_vErrorCache[i2] : ui(i2) - y2;
    const double s = y1 * y2;

    // Compute L, H via equations (13) and (14)
    double L = 0, H = 0;
    if (y1 == y2)
    {
        L = MAX(alph2 + alph1 - m_dC, 0.0);
        H = MIN(alph1 + alph2, m_dC);
    }
    else
    {
        L = MAX(alph2 - alph1, 0.0);
        H = MIN(m_dC, m_dC + alph2 - alph1); // equation 13: C + alph2 - alph1
    }
    if (L == H) return false;

    const double k11 = kernelRBF(i1, i1);
    const double k12 = kernelRBF(i1, i2);
    const double k22 = kernelRBF(i2, i2);
    const double eta = k11 + k22 - 2 * k12;

    double a2;
    if (eta > 0)
    {
        // Unconstrained optimum along the constraint line, clipped to [L, H].
        a2 = alph2 + y2 * (E1 - E2) / eta;
        if (a2 < L)
            a2 = L;
        else if (a2 > H)
            a2 = H;
    }
    else
    {
        // Degenerate curvature: evaluate the objective at both ends of the
        // segment and move to the lower one if the difference is significant.
        const double f1 = y1 * (E1 + m_dB) - alph1 * k11 - s * alph2 * k12;
        const double f2 = y2 * (E2 + m_dB) - s * alph1 * k12 - alph2 * k22;
        const double L1 = alph1 + s * (alph2 - L);
        const double H1 = alph1 + s * (alph2 - H);
        const double Lobj = L1 * f1 + L * f2 + (L1 * L1 * k11 + L * L * k22) / 2 + s * L * L1 * k12;
        const double Hobj = H1 * f1 + H * f2 + (H1 * H1 * k11 + H * H * k22) / 2 + s * H * H1 * k12;
        if (Lobj < Hobj - m_dEps)
            a2 = L;
        else if (Lobj > Hobj + m_dEps)
            a2 = H;
        else
            a2 = alph2;
    }

    // No meaningful progress on a2: give up on this pair.
    if (fabs(a2 - alph2) < m_dEps * (a2 + alph2 + m_dEps))
        return false;

    const double a1 = alph1 + s * (alph2 - a2);

    // Update threshold to reflect change in Lagrange multipliers.
    // b1 (b2) makes the new error of i1 (i2) exactly zero, so it is the right
    // choice whenever a1 (a2) is non-bound; only when both are at a bound is
    // the average used. This keeps the cache entries zeroed below consistent.
    const double b1 = E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12 + m_dB;
    const double b2 = E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22 + m_dB;
    const double oldB = m_dB;
    if (a1 > 0 && a1 < m_dC)
        m_dB = b1;
    else if (a2 > 0 && a2 < m_dC)
        m_dB = b2;
    else
        m_dB = (b1 + b2) / 2.0;
    const double delta_b = m_dB - oldB;

    // Update error cache using new Lagrange multipliers
    const double t1 = y1 * (a1 - alph1);
    const double t2 = y2 * (a2 - alph2);
    for (int i = 0; i < m_nTrainNumber; i++)
        if (m_vAlph[i] > 0 && m_vAlph[i] < m_dC)
            m_vErrorCache[i] += t1 * kernelRBF(i1, i) + t2 * kernelRBF(i2, i) - delta_b;
    m_vErrorCache[i1] = 0;
    m_vErrorCache[i2] = 0;

    // Store a1, a2 in the alpha array
    m_vAlph[i1] = a1;
    m_vAlph[i2] = a2;
    return true;
}

// Inner loop, called by the outer loop: if sample i1 violates the KKT
// conditions beyond the tolerance, pick a second multiplier heuristically and
// try to optimise the pair. Returns 1 on progress, 0 otherwise.
int SMO::examineExample(int i1)
{
    const double y1 = m_vTarget[i1];
    const double alph1 = m_vAlph[i1];

    // Cached error is only valid for a non-bound multiplier.
    const double E1 = (alph1 > 0 && alph1 < m_dC) ? m_vErrorCache[i1] : ui(i1) - y1;
    const double r1 = E1 * y1;

    if ((r1 < -m_dT && alph1 < m_dC) || (r1 > m_dT && alph1 > 0))
    {
        /* Three ways of choosing the second multiplier:
           1: among the non-bound samples, the one maximising fabs(E1 - E2)
           2: if that made no progress, every non-bound sample from a random start
           3: if that failed too, every sample from a random start */
        if (examineFirstChoice(i1, E1)) return 1;
        if (examineNonBound(i1)) return 1;
        if (examineBound(i1)) return 1;
    }
    return 0;
}

// Heuristic 1: among the non-bound samples take the one with the largest
// fabs(E1 - E2) and try a joint step with it.
int SMO::examineFirstChoice(int i1, double E1)
{
    int i2 = -1;
    double tmax = 0;
    for (int k = 0; k < m_nTrainNumber; k++)
    {
        if (m_vAlph[k] > 0 && m_vAlph[k] < m_dC)
        {
            const double E2 = m_vErrorCache[k];
            const double temp = fabs(E1 - E2);
            if (temp > tmax)
            {
                tmax = temp;
                i2 = k;
            }
        }
    }
    if (i2 >= 0 && takeStep(i1, i2)) return 1;
    return 0;
}

// Heuristic 2: walk all non-bound samples, starting at a random position.
int SMO::examineNonBound(int i1)
{
    const int k0 = rand() % m_nTrainNumber;
    for (int k = 0; k < m_nTrainNumber; k++)
    {
        const int i2 = (k + k0) % m_nTrainNumber;
        if ((m_vAlph[i2] > 0 && m_vAlph[i2] < m_dC) && takeStep(i1, i2))
            return 1;
    }
    return 0;
}

// Heuristic 3: walk every training sample, starting at a random position.
int SMO::examineBound(int i1)
{
    const int k0 = rand() % m_nTrainNumber;
    for (int k = 0; k < m_nTrainNumber; k++)
    {
        const int i2 = (k + k0) % m_nTrainNumber;
        if (takeStep(i1, i2)) return 1;
    }
    return 0;
}

/**********************************
 Trains the classifier.
   inputDataPath : path of the training data file
   s             : SMO parameters
 Does nothing (besides printing a message) when the classifier was loaded
 from a file.
 NOTE(review): m_bIsLoad is only assigned in init() and load(); confirm it is
 initialised before the first call to train().
***********************************/
void SMO::train(const char *inputDataPath, const SMOParams &s)
{
    if (m_bIsLoad)
    {
        cerr << "分类器已经得到,请勿再训练!" << endl;
        return;
    }
    init(s);
    readFile(inputDataPath);
    // Precompute the self dot products of every sample (test samples too,
    // because error_rate() evaluates the kernel on them).
    for (int i = 0; i < m_nAllSample; i++)
        m_vDotProductCache[i] = dotProduct(i, i);
    outerLoop();
}

// Reads samples in the sparse "label index:value index:value ..." format used
// by libsvm's heart_scale data set, one sample per line, at most m_nAllSample
// samples. Parsing is line based so a row that omits trailing features cannot
// swallow the next row, and feature indices outside [1, m_nDimension] are
// ignored instead of being written out of bounds. Features that are absent
// keep the value the row already holds.
// Prints a message and terminates the process when the file cannot be opened.
void SMO::readFile(const char *filePath)
{
    ifstream f(filePath);
    if (!f)
    {
        cerr << "训练数据读入失败!" << endl;
        exit(1);
    }

    string line;
    int row = 0;
    while (row < m_nAllSample && getline(f, line))
    {
        istringstream fields(line);
        if (!(fields >> m_vTarget[row]))
            continue; // blank line or line without a label

        int index = 0;  // 1-based feature index
        char sep = 0;   // the ':' between index and value
        double value = 0;
        while (fields >> index >> sep >> value)
        {
            if (sep == ':' && index >= 1 && index <= m_nDimension)
                m_dAllData[row][index - 1] = value;
        }
        ++row;
    }
}

// Prints the classification accuracy over the held-out samples, i.e. the
// samples with index in [m_nTrainNumber, m_nAllSample).
void SMO::error_rate()
{
    cout << "训练终于结束鸟" << endl;

    const int testCount = m_nAllSample - m_nTrainNumber;
    if (testCount <= 0)
    {
        // Nothing to evaluate; avoids the 0/0 division of the accuracy ratio.
        cout << "没有测试样本,无法计算精确度!" << endl;
        return;
    }

    int ac = 0;
    for (int i = m_nTrainNumber; i < m_nAllSample; i++)
    {
        const double tar = ui(i);
        if ((tar > 0 && m_vTarget[i] > 0) || (tar < 0 && m_vTarget[i] < 0))
            ac++;
    }
    const double accuracy = (double)ac / testCount;
    cout << "精确度:" << accuracy * 100 << "%" << endl;
}

// Copies the parameters and sizes every container up front.
// Vectors are assign()ed rather than resize()d so that a repeated call starts
// from zeroed multipliers and caches.
// NOTE(review): a matrix allocated by an earlier init()/load() is not released
// here - confirm whether repeated training needs to be supported.
void SMO::init(const SMOParams &s)
{
    m_nAllSample = s.m_nAllSample;
    m_nTrainNumber = s.m_nTrainNumber;
    m_nDimension = s.m_nDimension;
    m_dC = s.m_dC;
    m_dT = s.m_dT;
    m_dEps = s.m_dEps;
    m_dTwo_sigma_squared = s.m_dTwo_sigma_squared;
    m_bIsLoad = false;
    m_dB = 0.0;

    // Sample matrix, one zero-initialised row per sample.
    m_dAllData = new double*[m_nAllSample];
    for (int i = 0; i < m_nAllSample; ++i)
        m_dAllData[i] = new double[m_nDimension]();

    m_vTarget.assign(m_nAllSample, 0);
    m_vAlph.assign(m_nTrainNumber, 0);
    m_vErrorCache.assign(m_nTrainNumber, 0);
    m_vDotProductCache.assign(m_nAllSample, 0);
}

// Outer loop of the paper: alternates between a pass over all training
// samples and passes over the non-bound samples only, until a full pass
// changes nothing.
void SMO::outerLoop()
{
    int numChanged = 0;
    bool examineAll = true;
    while (numChanged > 0 || examineAll)
    {
        numChanged = 0;
        for (int i = 0; i < m_nTrainNumber; ++i)
        {
            const bool nonBound = m_vAlph[i] > 0 && m_vAlph[i] < m_dC;
            if (examineAll || nonBound)
                numChanged += examineExample(i);
        }
        if (examineAll)
            examineAll = false;
        else if (numChanged == 0)
            examineAll = true;
    }
}

// Writes the support vectors and the data needed for prediction to "svm.txt".
// First line: number of support vectors, dimension, RBF parameter, threshold b.
// Each following line: label, alpha, then the feature values of one support
// vector.
void SMO::save()
{
    ofstream outfile("svm.txt");
    if (!outfile)
    {
        cerr << "分类器保存失败!" << endl;
        return;
    }

    int countVec = 0; // number of support vectors
    for (int i = 0; i < m_nTrainNumber; ++i)
        if (m_vAlph[i] > 0)
            ++countVec;

    outfile << countVec << ' ' << m_nDimension << ' ' << m_dTwo_sigma_squared << ' ' << m_dB << '\n';
    for (int i = 0; i < m_nTrainNumber; ++i)
    {
        if (m_vAlph[i] > 0)
        {
            outfile << m_vTarget[i] << ' ' << m_vAlph[i] << ' ';
            for (int j = 0; j < m_nDimension; ++j)
                outfile << m_dAllData[i][j] << ' ';
            outfile << '\n';
        }
    }
}

// Loads a classifier written by save() so that predict() can be used.
// Prints a message and terminates the process when the file cannot be opened
// or its first line is not a usable header.
void SMO::load(const char *filePath)
{
    ifstream infile(filePath);
    if (!infile)
    {
        cerr << "分类器读入失败!" << endl;
        exit(1);
    }

    // Mark the classifier as loaded from a file rather than trained.
    m_bIsLoad = true;

    if (!(infile >> m_nAllSample >> m_nDimension >> m_dTwo_sigma_squared >> m_dB)
        || m_nAllSample < 0 || m_nDimension <= 0)
    {
        // Refuse to allocate from an unreadable or nonsensical header.
        cerr << "分类器读入失败!" << endl;
        exit(1);
    }
    m_nTrainNumber = m_nAllSample; // predict() iterates over m_nTrainNumber samples

    // Sample matrix, one zero-initialised row per support vector.
    m_dAllData = new double*[m_nAllSample];
    for (int i = 0; i < m_nAllSample; ++i)
        m_dAllData[i] = new double[m_nDimension]();

    m_vTarget.assign(m_nAllSample, 0);
    m_vAlph.assign(m_nAllSample, 0);

    for (int i = 0; i < m_nAllSample; ++i)
    {
        if (!(infile >> m_vTarget[i]))
            break; // file holds fewer rows than announced
        infile >> m_vAlph[i];
        for (int j = 0; j < m_nDimension; ++j)
            infile >> m_dAllData[i][j];
    }
}

// Releases every row of the sample matrix, then the row table.
// NOTE(review): assumes m_nAllSample and m_dAllData describe a matrix that was
// actually allocated by init() or load() - confirm for never-used objects.
SMO::~SMO()
{
    for (int i = 0; i < m_nAllSample; ++i)
    {
        delete [] m_dAllData[i];
    }
    delete [] m_dAllData;
}
#include "stdafx.h"
#include "smo.h"
#include <iostream>
using namespace std;

// Demo driver: loads a previously saved classifier from "svm.txt" and
// classifies two samples taken from the heart_scale data set.
int _tmain(int argc, _TCHAR* argv[])
{
    const int featureCount = 13; // number of features per sample

    SMOParams config;
    config.m_nAllSample = 270;            // total number of samples
    config.m_nTrainNumber = 200;          // number of training samples
    config.m_nDimension = featureCount;   // number of features
    config.m_dC = 1.0;                    // penalty parameter
    config.m_dT = 0.001;                  // tolerance in the KKT check
    config.m_dEps = 1.0E-12;              // epsilon
    config.m_dTwo_sigma_squared = 2.0;    // RBF kernel parameter

    SMO classifier;
    // Training path, disabled in favour of loading the saved classifier:
    //classifier.train("heart_scale.txt",config);
    //classifier.error_rate();
    //classifier.save();
    classifier.load("svm.txt");

    // Line 17 of heart_scale; the expected label is +1.
    double positiveSample[] = {
        -0.291667, 1, 1, -0.132075, -0.155251, -1, -1,
        -0.251908, 1, -0.419355, 0, 0.333333, 1
    };
    // Line 264 of heart_scale; the expected label is -1.
    double negativeSample[] = {
        -0.166667, 1, -0.333333, -0.320755, -0.360731, -1, -1,
        0.526718, -1, -0.806452, -1, -1, -1
    };

    const int firstLabel = classifier.predict(positiveSample, featureCount);
    cout << "预测的类别是:" << firstLabel << endl;
    cout << "真实的类别是: 1" << endl;
    cout << endl;

    const int secondLabel = classifier.predict(negativeSample, featureCount);
    cout << "预测的类别是:" << secondLabel << endl;
    cout << "真实的类别是: -1" << endl;
    cout << endl;

    return 0;
}
参考文献:
http://download.csdn.net/detail/gningh/5959555
http://download.csdn.net/detail/gningh/5959541
http://download.csdn.net/detail/gningh/5959495
http://www.cnblogs.com/vivounicorn/archive/2011/06/01/2067496.html
http://www.cnblogs.com/jerrylead/archive/2011/03/18/1988419.html
http://blog.csdn.net/techq/article/details/6171688
http://blog.csdn.net/keith0812/article/details/9129363
还有个是伯克利大学的C语言的版本,我没怎么看,推荐给喜欢用C的同学。
http://www.cs.berkeley.edu/~richie/stat242b/hw3/smo/smo.c
- SMO的C++实现
- 支持向量机SVM的SMO方法实现(C++)
- smo算法的c++实现
- 基于matlab的SMO实现
- SMO实现
- SVM 的实现之SMO算法
- Python实现SMO算法
- 统计学习方法:基于SMO算法的SVM的Python实现
- 基于SMO方法的支持向量机Pascal代码实现
- 用SMO算法实现了SVM的感悟
- SVM支持向量机(SMO算法)的R实现
- SVM中SMO算法的实现理论+代码
- SMO的MSDN文档
- svm的smo算法
- 改进的SMO算法
- SVM-SMO算法C++实现
- SMO优化算法(实现SVM)
- SVM实现之SMO算法
- Timesten使用直连的方法jdbc连接实例
- 二、OpenGL 的描述
- ZOJ 3195 Design the city LCA转RMQ
- 给设计师提供的十大无代码网站编辑器
- Spring中的事务管理,hibernate整合,struts整合(佟刚)
- SMO的C++实现
- 关于document.getElmentByName("name").value的值是undifind
- IE6不支持透明背景png图片
- C结构体之-位域
- Android SSL BKS证书的生成过程
- Log文件太大,手机ROM空间被占满
- 使用可恢复空间分配
- 教育办公系统的环境搭建,登陆的JavaScript验证,简单验证,复杂验证
- Git分支管理策略