学习OpenCV——SVM 手写数字检测

来源:互联网 发布:百战天虫豪华版java 编辑:程序博客网 时间:2024/04/29 00:53

转自http://blog.csdn.net/firefight/article/details/6452188

训练和测试数据来自MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

 

 

使用OPENCV训练手写数字识别分类器 

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

 

#include "stdafx.h"

#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

// Set to 1 to compile in the per-sample visualization code that is guarded by
// #if(SHOW_PROCESS) in the data-reading and prediction functions.
#define SHOW_PROCESS 0
// Set to 1 to make main() read the training data and train the SVM;
// with 0, main() runs the SVM prediction pass instead.
#define ON_STUDY 0

// One sample: 64 feature values plus the class label it belongs to.
class NumTrainData
{
public:
    // Creates a sample with every feature set to zero and the label set to -1.
    NumTrainData() : result(-1)
    {
        for (int k = 0; k < 64; ++k)
            data[k] = 0.0f;
    }

public:
    float data[64]; // feature values
    int result;     // label; -1 until one is assigned
};

// Samples appended by ReadTrainData(); main() hands this to the training function.
vector<NumTrainData> buffer;
// Number of float features per sample; used as the column count of the
// training and prediction matrices.
int featureLen = 64;

// Reverses, in place, the order of the four bytes starting at buf.
// buf must point to at least four writable bytes.
void swapBuffer(char* buf)
{
    for (int lo = 0, hi = 3; lo < hi; ++lo, --hi)
    {
        const char held = buf[lo];
        buf[lo] = buf[hi];
        buf[hi] = held;
    }
}

// Crops the non-zero (foreground) pixels of src and centres them inside a
// square, zero-filled image.
//
// src: single-channel 8-bit image (pixels are read as uchar).
// dst: receives a len x len CV_8UC1 image, where len is the longer side of
//      the foreground bounding box. When src contains no non-zero pixel,
//      dst receives an all-zero image of the same size as src instead.
void GetROI(Mat& src, Mat& dst)
{
    // Start with an "inverted" box so the first foreground pixel sets all edges.
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;

    //Get valid area
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    // No foreground at all: there is nothing to crop, and the box below
    // would have a negative size.
    if (right < left || bottom < top)
    {
        dst = Mat::zeros(src.rows, src.cols, CV_8UC1);
        return;
    }

    // left/right and top/bottom are inclusive pixel indices, hence the +1;
    // without it the last foreground column and row would be cut off.
    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = (width < height) ? height : width;

    //Create a square
    dst = Mat::zeros(len, len, CV_8UC1);

    //Copy valid data to the square's center
    const Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
    const Rect srcRect(left, top, width, height);
    Mat dstROI = dst(dstRect);
    src(srcRect).copyTo(dstROI);
}

int ReadTrainData(int maxCount)
{
 //Open image and label file
 const char fileName[] = "../res/train-images.idx3-ubyte";
 const char labelFileName[] = "../res/train-labels.idx1-ubyte";

 ifstream lab_ifs(labelFileName, ios_base::binary);
 ifstream ifs(fileName, ios_base::binary);

 if( ifs.fail() == true )
  return -1;

 if( lab_ifs.fail() == true )
  return -1;

 //Read train data number and image rows / cols
 char magicNum[4], ccount[4], crows[4], ccols[4];
 ifs.read(magicNum, sizeof(magicNum));
 ifs.read(ccount, sizeof(ccount));
 ifs.read(crows, sizeof(crows));
 ifs.read(ccols, sizeof(ccols));

 int count, rows, cols;
 swapBuffer(ccount);
 swapBuffer(crows);
 swapBuffer(ccols);

 memcpy(&count, ccount, sizeof(count));
 memcpy(&rows, crows, sizeof(rows));
 memcpy(&cols, ccols, sizeof(cols));

 //Just skip label header
 lab_ifs.read(magicNum, sizeof(magicNum));
 lab_ifs.read(ccount, sizeof(ccount));

 //Create source and show image matrix
 Mat src = Mat::zeros(rows, cols, CV_8UC1);
 Mat temp = Mat::zeros(8, 8, CV_8UC1);
 Mat img, dst;

 char label = 0;
 Scalar templateColor(255, 0, 255 );

 NumTrainData rtd;

 //int loop = 1000;
 int total = 0;

 while(!ifs.eof())
 {
  if(total >= count)
   break;
  
  total++;
  cout << total << endl;
  
  //Read label
  lab_ifs.read(&label, 1);
  label = label + '0';

  //Read source data
  ifs.read((char*)src.data, rows * cols);
  GetROI(src, dst);

#if(SHOW_PROCESS)
  //Too small to watch
  img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
  resize(dst, img, img.size());

  stringstream ss;
  ss << "Number " << label;
  string text = ss.str();
  putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

  //imshow("img", img);
#endif

  rtd.result = label;
  resize(dst, temp, temp.size());
  //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);

  for(int i = 0; i<8; i++)
  {
   for(int j = 0; j<8; j++)
   {
     rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
   }
  }

  buffer.push_back(rtd);

  //if(waitKey(0)==27) //ESC to quit
  // break;

  maxCount--;
  
  if(maxCount == 0)
   break;
 }

 ifs.close();
 lab_ifs.close();

 return 0;
}

// Trains a random-trees classifier on trainData and, if training succeeds,
// saves it to "new_rtrees.xml".
//
// trainData: samples to train on; each contributes one row of featureLen
//            floats and its `result` as the integer response.
// Does nothing when trainData is empty.
void newRtStudy(vector<NumTrainData>& trainData)
{
    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
        return; // nothing to train on

    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

    // Never copy more floats than a sample actually holds, even if
    // featureLen is changed to exceed the array size.
    const int storedLen = static_cast<int>(sizeof(trainData[0].data) / sizeof(trainData[0].data[0]));
    const int copyLen = (featureLen < storedLen) ? featureLen : storedLen;

    for (int i = 0; i < sampleCount; i++)
    {
        const NumTrainData& td = trainData[i];
        memcpy(data.ptr<float>(i), td.data, copyLen * sizeof(float));

        // res is CV_32SC1, so it must be accessed as int.
        res.at<int>(i, 0) = td.result;
    }

    /////////////START RT TRAINING//////////////////
    CvRTrees forest;

    const bool trained = forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
            CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    if (trained)
        forest.save( "new_rtrees.xml" );
}


int newRtPredict()
{
    CvRTrees forest;
 forest.load( "new_rtrees.xml" );

 const char fileName[] = "../res/t10k-images.idx3-ubyte";
 const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

 ifstream lab_ifs(labelFileName, ios_base::binary);
 ifstream ifs(fileName, ios_base::binary);

 if( ifs.fail() == true )
  return -1;

 if( lab_ifs.fail() == true )
  return -1;

 char magicNum[4], ccount[4], crows[4], ccols[4];
 ifs.read(magicNum, sizeof(magicNum));
 ifs.read(ccount, sizeof(ccount));
 ifs.read(crows, sizeof(crows));
 ifs.read(ccols, sizeof(ccols));

 int count, rows, cols;
 swapBuffer(ccount);
 swapBuffer(crows);
 swapBuffer(ccols);

 memcpy(&count, ccount, sizeof(count));
 memcpy(&rows, crows, sizeof(rows));
 memcpy(&cols, ccols, sizeof(cols));

 Mat src = Mat::zeros(rows, cols, CV_8UC1);
 Mat temp = Mat::zeros(8, 8, CV_8UC1);
 Mat m = Mat::zeros(1, featureLen, CV_32FC1);
 Mat img, dst;

 //Just skip label header
 lab_ifs.read(magicNum, sizeof(magicNum));
 lab_ifs.read(ccount, sizeof(ccount));

 char label = 0;
 Scalar templateColor(255, 0, 0);

 NumTrainData rtd;

 int right = 0, error = 0, total = 0;
 int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
 while(ifs.good())
 {
  //Read label
  lab_ifs.read(&label, 1);
  label = label + '0';

  //Read data
  ifs.read((char*)src.data, rows * cols);
  GetROI(src, dst);

  //Too small to watch
  img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
  resize(dst, img, img.size());

  rtd.result = label;
  resize(dst, temp, temp.size());
  //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
  for(int i = 0; i<8; i++)
  {
   for(int j = 0; j<8; j++)
   {
     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
   }
  }

  if(total >= count)
   break;

  char ret = (char)forest.predict(m);

  if(ret == label)
  {
   right++;
   if(total <= 5000)
    right_1++;
   else
    right_2++;
  }
  else
  {
   error++;
   if(total <= 5000)
    error_1++;
   else
    error_2++;
  }

  total++;

#if(SHOW_PROCESS)
  stringstream ss;
  ss << "Number " << label << ", predict " << ret;
  string text = ss.str();
  putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

  imshow("img", img);
  if(waitKey(0)==27) //ESC to quit
   break;
#endif

 }

 ifs.close();
 lab_ifs.close();

 stringstream ss;
 ss << "Total " << total << ", right " << right <<", error " << error;
 string text = ss.str();
 putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
 imshow("img", img);
 waitKey(0);

 return 0;
}

// Trains a C-SVC SVM with an RBF kernel on trainData and, if training
// succeeds, saves it to "SVM_DATA.xml".
//
// trainData: samples to train on; each feature vector is normalized with
//            cv::normalize before being stored as one row of the training
//            matrix, and `result` is used as the integer response.
// Does nothing when trainData is empty.
void newSvmStudy(vector<NumTrainData>& trainData)
{
    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
        return; // nothing to train on

    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

    // Never copy more floats than a sample actually holds, even if
    // featureLen is changed to exceed the array size.
    const int storedLen = static_cast<int>(sizeof(trainData[0].data) / sizeof(trainData[0].data[0]));
    const int copyLen = (featureLen < storedLen) ? featureLen : storedLen;

    for (int i = 0; i < sampleCount; i++)
    {
        const NumTrainData& td = trainData[i];

        // Normalize in a scratch row, then copy the result into row i.
        memcpy(m.data, td.data, copyLen * sizeof(float));
        normalize(m, m);
        memcpy(data.ptr<float>(i), m.data, featureLen * sizeof(float));

        // res is CV_32SC1, so it must be accessed as int.
        res.at<int>(i, 0) = td.result;
    }

    /////////////START SVM TRAINING//////////////////
    CvSVM svm;

    const CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    const CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    if (svm.train(data, res, Mat(), Mat(), param))
        svm.save( "SVM_DATA.xml" );
}


int newSvmPredict()
{
 CvSVM svm = CvSVM();
 svm.load( "SVM_DATA.xml" );

 const char fileName[] = "../res/t10k-images.idx3-ubyte";
 const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

 ifstream lab_ifs(labelFileName, ios_base::binary);
 ifstream ifs(fileName, ios_base::binary);

 if( ifs.fail() == true )
  return -1;

 if( lab_ifs.fail() == true )
  return -1;

 char magicNum[4], ccount[4], crows[4], ccols[4];
 ifs.read(magicNum, sizeof(magicNum));
 ifs.read(ccount, sizeof(ccount));
 ifs.read(crows, sizeof(crows));
 ifs.read(ccols, sizeof(ccols));

 int count, rows, cols;
 swapBuffer(ccount);
 swapBuffer(crows);
 swapBuffer(ccols);

 memcpy(&count, ccount, sizeof(count));
 memcpy(&rows, crows, sizeof(rows));
 memcpy(&cols, ccols, sizeof(cols));

 Mat src = Mat::zeros(rows, cols, CV_8UC1);
 Mat temp = Mat::zeros(8, 8, CV_8UC1);
 Mat m = Mat::zeros(1, featureLen, CV_32FC1);
 Mat img, dst;

 //Just skip label header
 lab_ifs.read(magicNum, sizeof(magicNum));
 lab_ifs.read(ccount, sizeof(ccount));

 char label = 0;
 Scalar templateColor(255, 0, 0);

 NumTrainData rtd;

 int right = 0, error = 0, total = 0;
 int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
 while(ifs.good())
 {
  //Read label
  lab_ifs.read(&label, 1);
  label = label + '0';

  //Read data
  ifs.read((char*)src.data, rows * cols);
  GetROI(src, dst);

  //Too small to watch
  img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
  resize(dst, img, img.size());

  rtd.result = label;
  resize(dst, temp, temp.size());
  //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
  for(int i = 0; i<8; i++)
  {
   for(int j = 0; j<8; j++)
   {
     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
   }
  }

  if(total >= count)
   break;

  normalize(m, m);
  char ret = (char)svm.predict(m);

  if(ret == label)
  {
   right++;
   if(total <= 5000)
    right_1++;
   else
    right_2++;
  }
  else
  {
   error++;
   if(total <= 5000)
    error_1++;
   else
    error_2++;
  }

  total++;

#if(SHOW_PROCESS)
  stringstream ss;
  ss << "Number " << label << ", predict " << ret;
  string text = ss.str();
  putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

  imshow("img", img);
  if(waitKey(0)==27) //ESC to quit
   break;
#endif

 }

 ifs.close();
 lab_ifs.close();

 stringstream ss;
 ss << "Total " << total << ", right " << right <<", error " << error;
 string text = ss.str();
 putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
 imshow("img", img);
 waitKey(0);

 return 0;
}

int main( int argc, char *argv[] )
{
#if(ON_STUDY)
 int maxCount = 60000;
 ReadTrainData(maxCount);

 //newRtStudy(buffer);
 newSvmStudy(buffer);
#else
 //newRtPredict();
 newSvmPredict();
#endif
 return 0;
}

 

原创粉丝点击