SVM算法的一点点理解

来源:互联网 发布:精准扶贫数据平台登录 编辑:程序博客网 时间:2024/05/19 19:57

最近几天看了下SVM算法,下面是我的个人理解。
SVM(支持向量机)的目标是找一个超平面,使得几何间隔最大化。为了方便讨论,下面只考虑二维的情况,并假设数据线性可分且只有两类(即二分类问题)。这样问题就是找一条直线,使得这条直线能够把正负两类点分开,并且使所有数据点中最小的几何间隔最大。
这里写图片描述
这里假设数据为 $(x_1,y_1),(x_2,y_2),\dots,(x_n,y_n)$,其中 $x_i$ 是样本点,标签 $y_i\in\{+1,-1\}$。
可以把这些数据看作坐标平面上的点。
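我们定义样本 $i$ 的函数间隔为 $\hat\gamma_i = y_i(w\cdot x_i+b)$,
几何间隔为 $\gamma_i = \dfrac{y_i(w\cdot x_i+b)}{\|w\|}$。
定义最小的函数间隔为 $T_1=\min_i \hat\gamma_i$(所有数据的函数间隔中的最小值),
最小的几何间隔为 $T_2=\min_i \gamma_i = T_1/\|w\|$(所有数据的几何间隔中的最小值)。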
那么SVM算法求的就是
$$\max_{w,b}\ T_2 \quad \text{s.t.}\quad \frac{y_i(w\cdot x_i+b)}{\|w\|}\ge T_2,\quad i=1,\dots,n$$
由于 $T_2=T_1/\|w\|$,可以写成
$$\max_{w,b}\ \frac{T_1}{\|w\|} \quad \text{s.t.}\quad y_i(w\cdot x_i+b)\ge T_1,\quad i=1,\dots,n$$
这里我们注意到,函数间隔是可以任意改变的:把 $w$ 和 $b$ 同时放大或缩小 $N$ 倍,函数间隔也随之变为原来的 $N$ 倍,但超平面本身和几何间隔都不变。因此可以不失一般性地把 $T_1$ 置为1,
则问题变为
$$\max_{w,b}\ \frac{1}{\|w\|} \quad \text{s.t.}\quad y_i(w\cdot x_i+b)\ge 1,\quad i=1,\dots,n$$

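即(最大化 $1/\|w\|$ 等价于最小化 $\frac12\|w\|^2$)
$$\min_{w,b}\ \frac12\|w\|^2 \quad \text{s.t.}\quad y_i(w\cdot x_i+b)\ge 1,\quad i=1,\dots,n$$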
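运用拉格朗日乘子法,引入乘子 $\alpha_i\ge 0$:
$$L(\alpha,w,b)=\frac12\|w\|^2-\sum_{i=1}^{n}\alpha_i\big(y_i(w\cdot x_i+b)-1\big)$$
对 $w$ 和 $b$ 求偏导数并令其为0,得到 $w=\sum_i\alpha_i y_i x_i$ 和 $\sum_i\alpha_i y_i=0$,代入 $L$ 即得到对偶形式:
$$\max_{\alpha}\ \sum_i\alpha_i-\frac12\sum_{i,j}\alpha_i\alpha_j y_i y_j\,(x_i\cdot x_j) \quad \text{s.t.}\quad \alpha_i\ge 0,\ \sum_i\alpha_i y_i=0$$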
这里写图片描述

下面我们来讨论下数据线性不可分的情况。
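当数据在二维空间中线性不可分时,把它映射到三维或者更高维的空间后可能就线性可分了。由于对偶形式中样本只以内积 $x_i\cdot x_j$ 的形式出现,我们只需用核函数 $K(x_i,x_j)=\phi(x_i)\cdot\phi(x_j)$ 代替内积,而不必显式地计算映射 $\phi$。
下面是两种常用的核函数: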
这里写图片描述

还有一种处理线性不可分的方法-软间隔
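我们允许分类器出现误差,为每个样本引入松弛变量 $\xi_i\ge 0$:
$$\min_{w,b,\xi}\ \frac12\|w\|^2+C\sum_{i=1}^{n}\xi_i \quad \text{s.t.}\quad y_i(w\cdot x_i+b)\ge 1-\xi_i,\ \ \xi_i\ge 0$$
其中 $C>0$ 是惩罚系数。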
可以理解为给分类器留出犯错的余地,但是每个误差都要付出代价,$C$ 越大,犯错的代价越高。

同样运用拉格朗日可以得到对偶形式
这里写图片描述
其中
这里写图片描述

这个对偶问题是一个凸二次规划问题,可以用通用的二次规划方法求解,但当样本数量很大时,这种方法的计算和内存开销都很大。

而微软研究院的一位大神 John Platt 提出的SMO(序列最小优化)算法,可以高效地求得(在给定精度内的)最优解。

这里简单讲讲SMO算法

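我们选取 $\alpha_1,\alpha_2$ 作为要更新的对象,其余的 $\alpha_i$ 固定不变。记 $s=\alpha_1y_1+\alpha_2y_2=-\sum_{i\ge 3}\alpha_iy_i$,根据约束条件 $\sum_i\alpha_iy_i=0$ 和 $0\le\alpha_i\le C$,可以得到 $\alpha_1=(s-\alpha_2y_2)y_1$(因为 $y_1^2=1$),并且得到 $\alpha_2$ 的取值范围 $[L,H]$。代入原式后就是一个关于 $\alpha_2$ 的一元二次问题。幸运的是,当 $\eta=K_{11}+K_{22}-2K_{12}>0$ 时目标函数是凸函数,可以直接求极值 $\alpha_2^{\text{new}}=\alpha_2+y_2(E_1-E_2)/\eta$(其中 $E_i=f(x_i)-y_i$ 为预测误差),再把结果截断到 $[L,H]$ 内即可。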

这里 $\alpha_1,\alpha_2$ 的选取可以运用启发式搜索(例如优先选择违反KKT条件的样本,再选使 $|E_1-E_2|$ 最大的 $\alpha_2$),这里就不细说了;下面的代码为了简单起见采用随机选取。

这里是代码。#include "stdafx.h"#include <opencv2/core/core.hpp>#include <opencv2/imgproc/imgproc.hpp>#include <opencv2/highgui/highgui.hpp>#include <iostream>#include <algorithm>#include <stdlib.h>#include <math.h>#include <time.h>#include <conio.h>#include <math.h>#include <dirent.h> #include <vector>using namespace cv;using namespace std;#define  X_LEN 32*30#define  S_LEN 64#define  TOT 64#define  C 1#define  EPS 1e-4#define W 3#define MAXN 80int x[S_LEN+1][X_LEN],y[S_LEN];double a[S_LEN],miss[S_LEN],b;double w[X_LEN],res[S_LEN];vector<string> ve[2];double pow2(double x){    return x*x;}int eq(double a){    if(fabs(a)<EPS) return 1;    else return 0;}double kernel(int s1,int s2){    /*int sum   = 0;    for(int i = 0;i<X_LEN;i++)    {        sum+=x[s1][i]*x[s2][i];    }    return sum;*/    double sum = 0;    for(int i = 0;i<X_LEN;i++)    {        sum+=pow2(x[s1][i]-x[s2][i]);    }    return exp(-sum/2*W*W);}void cal_w(){    for(int i = 0;i<X_LEN;i++)    {        double sum  = 0;        for(int j = 0;j<S_LEN;j++)        {             sum+=a[j]*y[j]*x[j][i];        }        w[i] = sum;    }}double cal_res(int s){    double sum =  0;    for(int i = 0;i<S_LEN;i++)    {        sum+=a[i]*y[i]*kernel(i,s);    }    return sum+b;}double cal_miss(int s){      return cal_res(s)-y[s];}int choose1(){/*  for(int i=0;i<S_LEN;i++)    {        if(a[i]>0&&a[i]<C&&!eq(y[i]*cal_res(i)-1)) return i;    }    for(int i=0;i<S_LEN;i++)    {        if(eq(a[i])&&y[i]*cal_res(i)-1<1-EPS) return i;        if(eq(a[i]-C)&&y[i]*cal_res(i)-1>1-EPS) return i;    }*/      return rand()%S_LEN;}int choose2(int index1){    /*int index2 = -1;    double maxnum = -1;    double m  = cal_miss(index1);    for(int i = 0;i<S_LEN;i++)    {        if(i==index1) continue;        if(fabs(m-cal_miss(i))>maxnum)            maxnum = fabs(m-cal_miss(i)),index2 = i;    }    */    //return index2;    int temp = rand()%S_LEN;    while(temp==index1) temp = rand()%S_LEN;    return temp;}void svm(){    int index1 = choose1();    
int index2 = choose2(index1);    int i,j,k;    //printf("123ssdfsfsdf1231\n");    double sum = 0;    double L,H;    for(i=0;i<S_LEN;i++)    {        if(i!=index1&&i!=index2)        {            sum+=a[i]*y[i];        }    }    sum*=-1;//  printf("%lf\n",sum);    if(y[index1]==y[index2])    {        if(y[index1]==-1)        {            H = min(C,(a[index1]+a[index2]));            L = max(0,(a[index1]+a[index2])-C);        }        else        {            H = min(C,a[index1]+a[index2]);            L = max(0,a[index1]+a[index2]-C);        }    }    else    {        if(y[index2]==-1)        {            swap(index1,index2);        }        H = min(C,C+a[index2]-a[index1]);        L = max(0,a[index2]-a[index1]);    }    //printf("%lf %lf\n",L,H);    for(i=0;i<S_LEN;i++)    {        miss[i] = cal_miss(i);    }    double temp = kernel(index1,index1)+kernel(index2,index2)-2*kernel(index1,index2);    double temp_a = a[index2]+y[index2]*(miss[index1]-miss[index2])/temp;    //printf("%lf %lf %lf %d %d %lf\n",a[index2],miss[index1],miss[index2],index1,index2,temp_a);    if(temp_a>=L&&temp_a<=H) a[index2] = temp_a;    else    {        double v1 = cal_res(index1)-b-a[index1]*y[index1]*kernel(index1,index1)-a[index2]*y[index2]*kernel(index1,index2);        double v2 = cal_res(index2)-b-a[index1]*y[index1]*kernel(index2,index1)-a[index2]*y[index2]*kernel(index2,index2);        double s1 = 0.5*kernel(index1,index1)*pow2(sum-L*y[index2])+0.5*kernel(index2,index2)*pow2(L)                  +y[index2]*kernel(index1,index2)*(sum-L*y[index2])*L-(sum-L*y[index2])*y[index1]-                  L+v1*(sum-L*y[index2])+y[index2]*v2*L;        double s2 = 0.5*kernel(index1,index1)*pow2(sum-H*y[index2])+0.5*kernel(index2,index2)*pow2(H)                  +y[index2]*kernel(index1,index2)*(sum-H*y[index2])*H-(sum-H*y[index2])*y[index1]-                  H+v1*(sum-H*y[index2])+y[index2]*v2*H;        if(s1<s2) a[index2] = L;        else a[index2] = H;    }    a[index1] = 
(sum-y[index2]*a[index2])*y[index1];    //printf("%d %d\n",index1,index2);    if(a[index1]>0&&a[index1]<C)    {        //printf("11111111111\n");        b = y[index1]-cal_res(index1)+b;    }    else if(a[index2]>0&&a[index2]<C)    {    //  printf("222222222222\n");        b = y[index2]-cal_res(index2)+b;    }    else    {    //  printf("33333333333333333\n");        b  =  (y[index1]-cal_res(index1)+2*b+y[index2]-cal_res(index2))/2;    }  // printf("%lf\n",a[index2]);}int check(double &pre){    double sum  = 0;    for(int i = 0;i<S_LEN;i++)    {        for(int j = 0;j<S_LEN;j++)        {            sum+=a[i]*a[j]*y[i]*y[j]*kernel(i,j);        }    }    sum*=0.5;    for(int i = 0;i<S_LEN;i++)    {        sum-=a[i];    }    //printf("%lf %lf\n",pre,sum);    //printf("1111111,%lf\n",pre-sum);    if(pre-sum>EPS) {        pre = sum;        return 1;    }    else {        pre = sum;        return 0;    }}void readdir(){      DIR *directory_pointer;      directory_pointer = opendir("d://ccut//faces_4//an2i");      struct dirent *entry;       int i;      while((entry = readdir(directory_pointer))!=NULL)      {          string s = "d://ccut//faces_4//an2i//";          if((*entry).d_name[0]=='.') continue;          s+=(*entry).d_name;          ve[0].push_back(s);         // cout<<s<<endl;          //char name[MAXN];         /* for(i=0;i<s.length();i++)          {              name[i] = s[i];           }          name[s.length()] = 0;          */          /* DIR * directory;          directory = opendir(name);          struct dirent *entry_tmp;          string temp  = s;          while((entry_tmp = readdir(directory))!=NULL)          {               s = temp;               if((entry_tmp->d_name)[0]=='.') continue;               s+="//";               s+=entry_tmp->d_name;               path[p_flag] = s;               path_name[p_flag++]+= entry_tmp->d_name;           }                          closedir(directory);         */      }      directory_pointer = 
opendir("d://ccut//faces_4//at33");      while((entry = readdir(directory_pointer))!=NULL)      {          string s = "d://ccut//faces_4//at33//";          if((*entry).d_name[0]=='.') continue;          s+=(*entry).d_name;          ve[1].push_back(s);      }      closedir(directory_pointer);}void init(){    //printf("%d %d\n",ve[0].size(),ve[1].size());    //printf("1111111\n");    readdir();    //printf("1111111");    b  = 0;//  printf("%d %d\n",ve[0].size(),ve[1].size());    for(int i = 0;i<S_LEN;i++)    {        a[i] = 0;    }    Mat mat;    int cnt = 0;    for(int i = 0;i<ve[0].size();i++)    {         y[i] = 1;        mat=  imread(ve[0][i]);        cnt = 0;        for(int j = 0;j<mat.rows;j++)        {            for(int k = 0;k<mat.cols;k++)            {                x[i][cnt++] = (int)(mat.at<uchar>(j,k));            }        }    }    for(int i = 0;i<ve[1].size();i++)    {        cnt = 0;        y[32+i] = -1;        mat=  imread(ve[1][i]);        for(int j = 0;j<mat.rows;j++)        {            for(int k = 0;k<mat.cols;k++)            {                x[32+i][cnt++] = (int)(mat.at<uchar>(j,k));            }        }    }    //printf("%d\n",cnt);    /*mat[0] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm");    mat[1] = imread("D:\\ccut\\faces_4\\an2i\\an2i_right_angry_open_4.pgm");    mat[2] = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_sunglasses_4.pgm");    y[0] = 1;    y[1] = -1;    y[2] = 1;    for(int i = 0;i<S_LEN;i++)    {        int cnt = 0;        for(int j = 0;j<mat[i].rows;j++)        {            for(int k = 0;k<mat[i].cols;k++)            {                x[i][cnt++] = (int)(mat[i].at<uchar>(i,j));            }        }    }    /*    //printf("%d %d\n",mat1.rows,mat1.cols);    x[0][0] = 1,x[0][1] = 1,y[0] = 1;    x[1][0] = 1,x[1][1] = 2,y[1] = -1;    x[2][0] = 1,x[2][1] = 3,y[2] = 1;    x[3][0] = 1,x[3][1] = 4,y[3] = -1;    */}int _tmain(int argc, _TCHAR* argv[]){    srand(time(0));    init();          double pre = 0;    
int flag = 0;    while(flag<=200)    {        double s = pre;        svm();        cal_w();        /*for(int i=0;i<S_LEN;i++)        {             printf("%lf ",a[i]);;        }        printf("\n");        for(int i = 0;i<X_LEN;i++)        {            printf("%lf ",w[i]);        }        printf("%lf\n",b);        printf("%lf\n",pre);        */        flag+=!check(pre);        printf("%d\n",flag);    }    //FILE *f = fopen("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm","rb");    //Mat mat1 = imread("D:\\ccut\\faces_4\\an2i\\an2i_left_angry_open_4.pgm");    //waitKey(111111);    for(int i = 0;i<=63;i++)    {        printf("%lf\n",cal_res(i));    }    //printf("%lf\n",cal_res(63));    return 0;}
0 0
原创粉丝点击