基于opencv的SVM实现

来源:互联网 发布:最新网络流行语段子 编辑:程序博客网 时间:2024/05/16 01:46

最近在研究关于图像的分类,其中要用到SVM,便在网上找了些代码,自己进行了一些修改,其中参考了部分其他人的博客。

程序代码如下:

#include <iostream>#include <opencv2/core/core.hpp>#include <opencv2/highgui/highgui.hpp>#include <opencv2/ml/ml.hpp>#include <CTYPE.H>  #defineNTRAINING_SAMPLES100// Number of training samples per class#define FRAC_LINEAR_SEP0.9f    // Fraction of samples which compose the linear separable partusing namespace cv;using namespace std;int main(int argc, char* argv[])  {  int size = 400; // height and widht of image  const int s = 1000; // number of data  int i, j,sv_num;  IplImage* img;  CvSVM svm ;  CvSVMParams param;  CvTermCriteria criteria; // 停止迭代标准  CvRNG rng = cvRNG();  CvPoint pts[s]; // 定义1000个点  float data[s*2]; // 点的坐标  int res[s]; // 点的类别  CvMat data_mat, res_mat;  CvScalar rcolor;  const float* support;  // 图像区域的初始化  img = cvCreateImage(cvSize(size,size),IPL_DEPTH_8U,3);  cvZero(img);  // 学习数据的生成  for (i=0; i<s;++i)  {  pts[i].x = cvRandInt(&rng)%size;  pts[i].y = cvRandInt(&rng)%size;  if (pts[i].y>50*cos(pts[i].x*CV_PI/100)+200)  {  cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(255,0,0));  cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(255,0,0));  res[i]=1;  }  else  {  if (pts[i].x>200)  {  cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(0,255,0));  cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(0,255,0));  res[i]=2;  }  else  {  cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(0,0,255));  cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(0,0,255));  res[i]=3;  }  }  }  // 学习数据的现实  cvNamedWindow("SVM训练样本空间及分类",CV_WINDOW_AUTOSIZE);  cvShowImage("SVM训练样本空间及分类",img);  cvWaitKey(0);  // 学习参数的生成  for (i=0;i<s;++i)  {  data[i*2] = float(pts[i].x)/size;  data[i*2+1] = float(pts[i].y)/size;  }  cvInitMatHeader(&data_mat,s,2,CV_32FC1,data);  cvInitMatHeader(&res_mat,s,1,CV_32SC1,res);  criteria = cvTermCriteria(CV_TERMCRIT_EPS,1000,FLT_EPSILON);  param = CvSVMParams(CvSVM::C_SVC,CvSVM::RBF,10.0,8.0,1.0,10.0,0.5,0.1,NULL,criteria);  svm.train(&data_mat,&res_mat,NULL,NULL,param);  // 学习结果绘图  for (i=0;i<size;i++)  {  for (j=0;j<size;j++)  {  CvMat m;  float ret = 0.0;  float a[] = {float(j)/size,float(i)/size};  cvInitMatHeader(&m,1,2,CV_32FC1,a);  ret = svm.predict(&m);  switch((int)ret)  {  case 1:  rcolor = CV_RGB(100,0,0);  break;  case 2:  rcolor = CV_RGB(0,100,0);  break;  case 3:  rcolor = CV_RGB(0,0,100);  break;  }  cvSet2D(img,i,j,rcolor);  }  }  // 为了显示学习结果,通过对输入图像区域的所有像素(特征向量)进行分类,然后对输入的像素用所属颜色等级的颜色绘图  for(i=0;i<s;++i)  {  CvScalar rcolor;  switch(res[i])  {  case 1:  rcolor = CV_RGB(255,0,0);  break;  case 2:  rcolor = CV_RGB(0,255,0);  break;  case 3:  rcolor = CV_RGB(0,0,255);  break;  }  cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),rcolor);  cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),rcolor);             }  // 支持向量的绘制  sv_num = svm.get_support_vector_count();  for (i=0; i<sv_num;++i)  {  support = svm.get_support_vector(i);  cvCircle(img,cvPoint((int)(support[0]*size),(int)(support[i]*size)),5,CV_RGB(200,200,200));  }  cvNamedWindow("SVM",CV_WINDOW_AUTOSIZE);  cvShowImage("SVM分类结果及支持向量",img);  cvWaitKey(0);  cvDestroyWindow("SVM");  cvReleaseImage(&img);  return 0;  }  

效果图:

1 0