SVM算法的一点点理解
来源:互联网 发布:精准扶贫数据平台登录 编辑:程序博客网 时间:2024/05/19 19:57
最近几天看了下SVM算法,下面是我的个人理解。
SVM(支持向量机)是为了找一个超平面,使得几何间隔最大化,为了方便讨论,下面只考虑二维的情况,并且数据线性可分且数据的种类只有两种(即二分类问题),这样问题就是找一个直线使得这条直线能够把正负两类点分割,并且使得所有数据的最小几何间隔最大。
这里假设数据为(x1,y1),(x2,y2)…….(xn,yn)
可以把数据表示为坐标轴上的点。
我们定义函数间隔为yi*(wxi+b)
几何间隔为yi*(wxi+b)/||w||
我们定义最小的函数间隔为T1 = min(所以数据的函数间隔)
最小的几个间隔为T2 = min(所以数据的几何间隔)
那么SVM算法求的就是
max T2
st yi*(wxi+b)/||w||>=T2
可以写成
max T1/||w||
st yi*(wxi+b)/||w||>=T2/||w||
这里我们注意到函数间隔这个约束是可以任意改变的,比如缩小或者放大N倍,这里可以把T1置为1
则 max 1/||w||
st yi*(wxi+b)>=1
即 min ·1/2||w||^2
st yi*(wxi+b)>=1
运用朗格朗日
L(a,w,b) = 1/2||w||^2-sigma(ai(yi*(wxi+b)-1))
对w和b求偏导数带入得到对偶形势
下面我们来讨论下数据线性不可分的情况。
当数据在二维线性不可分的情况下映射到三维,或者更高维可能就线性可分了。
下面的两种核函数
还有一种处理线性不可分的方法-软间隔
我们允许分类器出现误差。
即 min 1/2||w||+C*sigma(mi)
st yi*(w*xi)>=1-mi
可以理解为给分类器一个犯错的允许,但是每个误差都是要付出代价的
同样运用拉格朗日可以得到对偶形式
其中
如何求得这个问题的解是NP难问题
但是微软研究院的以为大神的SMO算法可以逼近最优
这里简单讲讲SMO算法
我们选取a1,a2作为要更新的对象,那么根据限制条件sigma(ai*yi) = 0,和0<=ai<=c可以得到a1 = (s-a2*y2)*y1;并且得到a2的取值范围,带入原式就是个一元问题。幸运的是原函数是个凸函数,可以直接求极值。
这里a1,a2的选取可以运用启发式搜素,这里就不说了。
这里是代码。#include "stdafx.h"#include <opencv2/core/core.hpp>#include <opencv2/imgproc/imgproc.hpp>#include <opencv2/highgui/highgui.hpp>#include <iostream>#include <algorithm>#include <stdlib.h>#include <math.h>#include <time.h>#include <conio.h>#include <math.h>#include <dirent.h> #include <vector>using namespace cv;using namespace std;#define X_LEN 32*30#define S_LEN 64#define TOT 64#define C 1#define EPS 1e-4#define W 3#define MAXN 80int x[S_LEN+1][X_LEN],y[S_LEN];double a[S_LEN],miss[S_LEN],b;double w[X_LEN],res[S_LEN];vector<string> ve[2];double pow2(double x){ return x*x;}int eq(double a){ if(fabs(a)<EPS) return 1; else return 0;}double kernel(int s1,int s2){ /*int sum = 0; for(int i = 0;i<X_LEN;i++) { sum+=x[s1][i]*x[s2][i]; } return sum;*/ double sum = 0; for(int i = 0;i<X_LEN;i++) { sum+=pow2(x[s1][i]-x[s2][i]); } return exp(-sum/2*W*W);}void cal_w(){ for(int i = 0;i<X_LEN;i++) { double sum = 0; for(int j = 0;j<S_LEN;j++) { sum+=a[j]*y[j]*x[j][i]; } w[i] = sum; }}double cal_res(int s){ double sum = 0; for(int i = 0;i<S_LEN;i++) { sum+=a[i]*y[i]*kernel(i,s); } return sum+b;}double cal_miss(int s){ return cal_res(s)-y[s];}int choose1(){/* for(int i=0;i<S_LEN;i++) { if(a[i]>0&&a[i]<C&&!eq(y[i]*cal_res(i)-1)) return i; } for(int i=0;i<S_LEN;i++) { if(eq(a[i])&&y[i]*cal_res(i)-1<1-EPS) return i; if(eq(a[i]-C)&&y[i]*cal_res(i)-1>1-EPS) return i; }*/ return rand()%S_LEN;}int choose2(int index1){ /*int index2 = -1; double maxnum = -1; double m = cal_miss(index1); for(int i = 0;i<S_LEN;i++) { if(i==index1) continue; if(fabs(m-cal_miss(i))>maxnum) maxnum = fabs(m-cal_miss(i)),index2 = i; } */ //return index2; int temp = rand()%S_LEN; while(temp==index1) temp = rand()%S_LEN; return temp;}void svm(){ int index1 = choose1(); int index2 = choose2(index1); int i,j,k; //printf("123ssdfsfsdf1231\n"); double sum = 0; double L,H; for(i=0;i<S_LEN;i++) { if(i!=index1&&i!=index2) { sum+=a[i]*y[i]; } } sum*=-1;// printf("%lf\n",sum); if(y[index1]==y[index2]) { if(y[index1]==-1) { H = min(C,(a[index1]+a[index2])); L = max(0,(a[index1]+a[index2])-C); } else { H = min(C,a[index1]+a[index2]); L = max(0,a[index1]+a[index2]-C); } } else { if(y[index2]==-1) { swap(index1,index2); } H = min(C,C+a[index2]-a[index1]); L = max(0,a[index2]-a[index1]); } //printf("%lf %lf\n",L,H); for(i=0;i<S_LEN;i++) { miss[i] = cal_miss(i); } double temp = kernel(index1,index1)+kernel(index2,index2)-2*kernel(index1,index2); double temp_a = a[index2]+y[index2]*(miss[index1]-miss[index2])/temp; //printf("%lf %lf %lf %d %d %lf\n",a[index2],miss[index1],miss[index2],index1,index2,temp_a); if(temp_a>=L&&temp_a<=H) a[index2] = temp_a; else { double v1 = cal_res(index1)-b-a[index1]*y[index1]*kernel(index1,index1)-a[index2]*y[index2]*kernel(index1,index2); double v2 = cal_res(index2)-b-a[index1]*y[index1]*kernel(index2,index1)-a[index2]*y[index2]*kernel(index2,index2); double s1 = 0.5*kernel(index1,index1)*pow2(sum-L*y[index2])+0.5*kernel(index2,index2)*pow2(L) +y[index2]*kernel(index1,index2)*(sum-L*y[index2])*L-(sum-L*y[index2])*y[index1]- L+v1*(sum-L*y[index2])+y[index2]*v2*L; double s2 = 0.5*kernel(index1,index1)*pow2(sum-H*y[index2])+0.5*kernel(index2,index2)*pow2(H) +y[index2]*kernel(index1,index2)*(sum-H*y[index2])*H-(sum-H*y[index2])*y[index1]- H+v1*(sum-H*y[index2])+y[index2]*v2*H; if(s1<s2) a[index2] = L; else a[index2] = H; } a[index1] = (sum-y[index2]*a[index2])*y[index1]; //printf("%d %d\n",index1,index2); if(a[index1]>0&&a[index1]<C) { //printf("11111111111\n"); b = y[index1]-cal_res(index1)+b; } else if(a[index2]>0&&a[index2]<C) { // printf("222222222222\n"); b = y[index2]-cal_res(index2)+b; } else { // printf("33333333333333333\n"); b = (y[index1]-cal_res(index1)+2*b+y[index2]-cal_res(index2))/2; } // printf("%lf\n",a[index2]);}int check(double &pre){ double sum = 0; for(int i = 0;i<S_LEN;i++) { for(int j = 0;j<S_LEN;j++) { sum+=a[i]*a[j]*y[i]*y[j]*kernel(i,j); } } sum*=0.5; for(int i = 0;i<S_LEN;i++) { sum-=a[i]; } //printf("%lf %lf\n",pre,sum); //printf("1111111,%lf\n",pre-sum); if(pre-sum>EPS) { pre = sum; return 1; } else { pre = sum; return 0; }}void readdir(){ DIR *directory_pointer; directory_pointer = opendir("d://ccut//faces_4//an2i"); struct dirent *entry; int i; while((entry = readdir(directory_pointer))!=NULL) { string s = "d://ccut//faces_4//an2i//"; if((*entry).d_name[0]=='.') continue; s+=(*entry).d_name; ve[0].push_back(s); // cout<<s<<endl; //char name[MAXN]; /* for(i=0;i<s.length();i++) { name[i] = s[i]; } name[s.length()] = 0; */ /* DIR * directory; directory = opendir(name); struct dirent *entry_tmp; string temp = s; while((entry_tmp = readdir(directory))!=NULL) { s = temp; if((entry_tmp->d_name)[0]=='.') continue; s+="//"; s+=entry_tmp->d_name; path[p_flag] = s; path_name[p_flag++]+= entry_tmp->d_name; } closedir(directory); */ } directory_pointer = opendir("d://ccut//faces_4//at33"); while((entry = readdir(directory_pointer))!=NULL) { string s = "d://ccut//faces_4//at33//"; if((*entry).d_name[0]=='.') continue; s+=(*entry).d_name; ve[1].push_back(s); } closedir(directory_pointer);}void init(){ //printf("%d %d\n",ve[0].size(),ve[1].size()); //printf("1111111\n"); readdir(); //printf("1111111"); b = 0;// printf("%d %d\n",ve[0].size(),ve[1].size()); for(int i = 0;i<S_LEN;i++) { a[i] = 0; } Mat mat; int cnt = 0; for(int i = 0;i<ve[0].size();i++) { y[i] = 1; mat= imread(ve[0][i]); cnt = 0; for(int j = 0;j<mat.rows;j++) { for(int k = 0;k<mat.cols;k++) { x[i][cnt++] = (int)(mat.at<uchar>(j,k)); } } } for(int i = 0;i<ve[1].size();i++) { cnt = 0; y[32+i] = -1; mat= imread(ve[1][i]); for(int j = 0;j<mat.rows;j++) { for(int k = 0;k<mat.cols;k++) { x[32+i][cnt++] = (int)(mat.at<uchar>(j,k)); } } } //printf("%d\n",cnt); /*mat[0] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm"); mat[1] = imread("D:\\ccut\\faces_4\\an2i\\an2i_right_angry_open_4.pgm"); mat[2] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_sunglasses_4.pgm"); y[0] = 1; y[1] = -1; y[2] = 1; for(int i = 0;i<S_LEN;i++) { int cnt = 0; for(int j = 0;j<mat[i].rows;j++) { for(int k = 0;k<mat[i].cols;k++) { x[i][cnt++] = (int)(mat[i].at<uchar>(i,j)); } } } /* //printf("%d %d\n",mat1.rows,mat1.cols); x[0][0] = 1,x[0][1] = 1,y[0] = 1; x[1][0] = 1,x[1][1] = 2,y[1] = -1; x[2][0] = 1,x[2][1] = 3,y[2] = 1; x[3][0] = 1,x[3][1] = 4,y[3] = -1; */}int _tmain(int argc, _TCHAR* argv[]){ srand(time(0)); init(); double pre = 0; int flag = 0; while(flag<=200) { double s = pre; svm(); cal_w(); /*for(int i=0;i<S_LEN;i++) { printf("%lf ",a[i]);; } printf("\n"); for(int i = 0;i<X_LEN;i++) { printf("%lf ",w[i]); } printf("%lf\n",b); printf("%lf\n",pre); */ flag+=!check(pre); printf("%d\n",flag); } //FILE *f = fopen("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm","rb"); //Mat mat1 = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm"); //waitKey(111111); for(int i = 0;i<=63;i++) { printf("%lf\n",cal_res(i)); } //printf("%lf\n",cal_res(63)); return 0;}
- SVM算法的一点点理解
- SVM算法的理解
- 对于KMP算法的一点点理解(仅仅就是一点点)
- 我对搜索算法的一点点理解
- 如何更好的理解SVM算法
- Enum的一点点理解
- 一点点指针的理解
- 对于bresenham画圆算法的一点点理解
- maven的一点点的理解
- 对jsp的一点点理解
- 一点点arm bootloader的理解
- 关于latch的一点点理解
- 有关KMP的一点点理解
- 对指针的一点点理解
- 对static的一点点理解
- 对Thrift的一点点理解
- 对Thrift的一点点理解
- 对Dijkstra的一点点理解。
- [读书笔记]
- poj 3641 快速幂
- 嵌套全选jq代码
- 屏幕适配
- 二叉树中和为某一值的路径
- SVM算法的一点点理解
- NOIP提高组 树塔狂想曲
- NYOJ-117 求逆序数(离散化+树状数组)/(归并)
- Android面试题:横竖屏切换的生命周期
- [python]判断列表为空
- 字符串移位包含的问题
- CococaPods 前前后后
- 操作系统-处理机调度
- C++new分配内存空间