OpenCV的svm手写字检测

来源:互联网 发布:mac os x server 下载 编辑:程序博客网 时间:2024/05/01 03:59

转自http://blog.csdn.net/firefight/article/details/6452188

训练和测试数据使用的是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

 

使用OPENCV训练手写数字识别分类器 

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
  1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

 

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

#include "stdafx.h"

#include <cfloat>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

// Set to 1 to render every sample while it is being processed.
#define SHOW_PROCESS 0
// Set to 1 to train a model; 0 evaluates a previously saved model.
#define ON_STUDY 0

// Side length of the square grid every digit is shrunk to before its pixels
// are used as features (kGridSize * kGridSize must equal featureLen).
static const int kGridSize = 8;

/// One labelled sample: 64 grid values plus the class label.
/// The label is stored as the ASCII code of the digit ('0'..'9'), -1 if unset.
class NumTrainData
{
public:
    NumTrainData()
    {
        memset(data, 0, sizeof(data));
        result = -1;
    }
public:
    float data[64];
    int result;
};

// Samples collected by ReadTrainData().
vector<NumTrainData> buffer;
// Number of features per sample.
int featureLen = 64;

/// Reverses the byte order of the 4-byte value stored at buf, in place.
/// No longer used by the readers below (they decode big-endian values
/// explicitly); retained so existing callers keep linking.
void swapBuffer(char* buf)
{
    char temp;
    temp = *(buf);
    *buf = *(buf+3);
    *(buf+3) = temp;

    temp = *(buf+1);
    *(buf+1) = *(buf+2);
    *(buf+2) = temp;
}

/// Reads a 32-bit big-endian integer from `in` into `value`.
/// Decoding byte by byte gives the same result on any host byte order.
/// @return false when the four bytes could not be read.
static bool readBigEndianInt32(istream& in, int& value)
{
    unsigned char bytes[4];
    in.read(reinterpret_cast<char*>(bytes), sizeof(bytes));
    if (!in)
        return false;

    const unsigned int decoded = (static_cast<unsigned int>(bytes[0]) << 24) |
                                 (static_cast<unsigned int>(bytes[1]) << 16) |
                                 (static_cast<unsigned int>(bytes[2]) << 8)  |
                                  static_cast<unsigned int>(bytes[3]);
    value = static_cast<int>(decoded);
    return true;
}

/// Consumes the headers of an image file and its label file.
/// Image header: magic, count, rows, cols. Label header: magic, count.
/// @return false on a short read or non-positive image dimensions.
static bool readMnistHeaders(istream& images, istream& labels,
                             int& count, int& rows, int& cols)
{
    int magic = 0;
    int labelCount = 0;

    if (!readBigEndianInt32(images, magic) ||
        !readBigEndianInt32(images, count) ||
        !readBigEndianInt32(images, rows) ||
        !readBigEndianInt32(images, cols))
        return false;

    // The label header is only skipped, as before; its values are not used.
    if (!readBigEndianInt32(labels, magic) ||
        !readBigEndianInt32(labels, labelCount))
        return false;

    return count >= 0 && rows > 0 && cols > 0;
}

/// Reads the next label byte and the next image (src.rows * src.cols bytes).
/// `src` must be a pre-allocated continuous CV_8UC1 matrix.
/// The label is returned as an ASCII digit (raw value + '0').
/// @return false when either stream could not deliver a full sample.
static bool readMnistSample(istream& images, istream& labels, Mat& src, char& label)
{
    char raw = 0;
    labels.read(&raw, 1);
    images.read(reinterpret_cast<char*>(src.data), src.rows * src.cols);
    if (!labels || !images)
        return false;

    label = static_cast<char>(raw + '0');
    return true;
}

/// Copies the bounding box of all non-zero pixels of `src` (CV_8UC1) into the
/// centre of a new zero-filled square matrix `dst`.
///
/// The bounding box is inclusive on all four sides. Earlier revisions used
/// right-left / bottom-top as the size and therefore dropped the last row
/// and column of the glyph; models trained with that revision should be
/// retrained. An image without any non-zero pixel yields a 1x1 zero matrix
/// so that callers can still resize the result.
void GetROI(Mat& src, Mat& dst)
{
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;

    //Get valid area
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    // Nothing drawn: there is no bounding box to extract.
    if (right < left || bottom < top)
    {
        dst = Mat::zeros(1, 1, CV_8UC1);
        return;
    }

    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = (width < height) ? height : width;

    //Create a square
    dst = Mat::zeros(len, len, CV_8UC1);

    //Copy valid data to square center
    Rect dstRect((len - width)/2, (len - height)/2, width, height);
    Rect srcRect(left, top, width, height);
    Mat dstROI = dst(dstRect);
    Mat srcROI = src(srcRect);
    srcROI.copyTo(dstROI);
}

/// Shrinks `roi` to kGridSize x kGridSize and writes the grey values row by
/// row into `features`, which must have room for kGridSize * kGridSize floats.
static void extractFeatures(const Mat& roi, float* features)
{
    Mat grid;
    resize(roi, grid, Size(kGridSize, kGridSize));
    //threshold(grid, grid, 10, 1, CV_THRESH_BINARY);

    for (int i = 0; i < kGridSize; i++)
    {
        const uchar* row = grid.ptr<uchar>(i);
        for (int j = 0; j < kGridSize; j++)
            features[i*kGridSize + j] = row[j];
    }
}

/// Loads training samples from ../res/train-images.idx3-ubyte and
/// ../res/train-labels.idx1-ubyte and appends them to the global `buffer`.
///
/// @param maxCount stop after this many samples; a value <= 0 means
///                 "no limit" (every sample announced by the header is read).
/// @return 0 on success, -1 when a file cannot be opened or a header is
///         unreadable.
int ReadTrainData(int maxCount)
{
    //Open image and label file
    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";

    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);

    if (ifs.fail() || lab_ifs.fail())
        return -1;

    //Read train data number and image rows / cols, skip label header
    int count = 0, rows = 0, cols = 0;
    if (!readMnistHeaders(ifs, lab_ifs, count, rows, cols))
        return -1;

    //Create source matrix
    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat dst;

#if(SHOW_PROCESS)
    Mat img;
    Scalar templateColor(255, 0, 255 );
#endif

    char label = 0;
    NumTrainData rtd;
    int total = 0;

    while (total < count)
    {
        // Stop on a truncated file instead of storing a stale image.
        if (!readMnistSample(ifs, lab_ifs, src, label))
            break;

        total++;
        cout << total << endl;

        GetROI(src, dst);

#if(SHOW_PROCESS)
        //Too small to watch
        img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
        resize(dst, img, img.size());

        stringstream ss;
        ss << "Number " << label;
        string text = ss.str();
        putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

        //imshow("img", img);
#endif

        rtd.result = label;
        extractFeatures(dst, rtd.data);
        buffer.push_back(rtd);

        //if(waitKey(0)==27) //ESC to quit
        //  break;

        maxCount--;
        if (maxCount == 0)
            break;
    }

    ifs.close();
    lab_ifs.close();

    return 0;
}

/// Packs `trainData` into a CV_32FC1 sample matrix (one row per sample) and
/// a CV_32SC1 label column.
/// @param normalizeFeatures when true every row is passed through
///        cv::normalize() before it is stored.
/// @return false when there is nothing to train on.
static bool buildTrainingMatrices(const vector<NumTrainData>& trainData,
                                  bool normalizeFeatures, Mat& data, Mat& res)
{
    // NumTrainData::data holds exactly kGridSize*kGridSize floats.
    CV_Assert(featureLen == kGridSize * kGridSize);

    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
        return false;

    data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    res = Mat::zeros(sampleCount, 1, CV_32SC1);

    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    for (int i = 0; i < sampleCount; i++)
    {
        const NumTrainData& td = trainData[i];
        memcpy(m.data, td.data, featureLen*sizeof(float));
        if (normalizeFeatures)
            normalize(m, m);
        memcpy(data.ptr<float>(i), m.data, featureLen*sizeof(float));

        // The label matrix is CV_32SC1, so it has to be accessed as int.
        res.at<int>(i, 0) = td.result;
    }
    return true;
}

/// Trains a random-trees classifier on `trainData` and saves it to
/// "new_rtrees.xml". Does nothing (apart from a message) for an empty set.
void newRtStudy(vector<NumTrainData>& trainData)
{
    Mat data, res;
    if (!buildTrainingMatrices(trainData, false, data, res))
    {
        cerr << "newRtStudy: no training data" << endl;
        return;
    }

    /////////////START RT TRAINNING//////////////////
    CvRTrees forest;
    forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
            CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    forest.save( "new_rtrees.xml" );
}

/// Runs `model` over ../res/t10k-images.idx3-ubyte / t10k-labels.idx1-ubyte,
/// prints the totals and shows them on top of the last processed digit.
///
/// Model must offer `predict(const Mat&)` returning the class label.
/// @param normalizeFeatures must match the setting used during training.
/// @return 0 on success, -1 when a file cannot be opened or a header is
///         unreadable.
template <typename Model>
static int evaluateOnTestSet(Model& model, bool normalizeFeatures)
{
    CV_Assert(featureLen == kGridSize * kGridSize);

    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);

    if (ifs.fail() || lab_ifs.fail())
        return -1;

    int count = 0, rows = 0, cols = 0;
    if (!readMnistHeaders(ifs, lab_ifs, count, rows, cols))
        return -1;

    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat img, dst;

    char label = 0;
    Scalar templateColor(255, 0, 0);

    int rightCount = 0, errorCount = 0, total = 0;

    // Check the limit before reading so no sample past the announced count
    // is ever fetched or processed.
    while (total < count && readMnistSample(ifs, lab_ifs, src, label))
    {
        GetROI(src, dst);
        extractFeatures(dst, m.ptr<float>(0));

        if (normalizeFeatures)
            normalize(m, m);
        const char ret = static_cast<char>(model.predict(m));

        if (ret == label)
            rightCount++;
        else
            errorCount++;

        total++;

#if(SHOW_PROCESS)
        //Too small to watch
        img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
        resize(dst, img, img.size());

        stringstream ss;
        ss << "Number " << label << ", predict " << ret;
        string text = ss.str();
        putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

        imshow("img", img);
        if(waitKey(0)==27) //ESC to quit
            break;
#endif
    }

    ifs.close();
    lab_ifs.close();

    stringstream ss;
    ss << "Total " << total << ", right " << rightCount <<", error " << errorCount;
    string text = ss.str();
    cout << text << endl;

    // Enlarge the last digit only once, here, instead of for every sample.
    if (img.empty() && !dst.empty())
        resize(dst, img, Size(dst.cols*30, dst.rows*30));

    // Without any processed sample there is no canvas to draw on.
    if (!img.empty())
    {
        putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
        imshow("img", img);
        waitKey(0);
    }

    return 0;
}

/// Loads "new_rtrees.xml" and evaluates it on the test set.
/// @return 0 on success, -1 when the model or the data cannot be loaded.
int newRtPredict()
{
    CvRTrees forest;
    forest.load( "new_rtrees.xml" );
    // load() reports nothing, so look at what it produced.
    if (forest.get_tree_count() <= 0)
    {
        cerr << "newRtPredict: cannot load new_rtrees.xml" << endl;
        return -1;
    }

    return evaluateOnTestSet(forest, false);
}

/// Trains an RBF C-SVC on the normalised features of `trainData` and saves it
/// to "SVM_DATA.xml". Does nothing (apart from a message) for an empty set.
void newSvmStudy(vector<NumTrainData>& trainData)
{
    Mat data, res;
    if (!buildTrainingMatrices(trainData, true, data, res))
    {
        cerr << "newSvmStudy: no training data" << endl;
        return;
    }

    /////////////START SVM TRAINNING//////////////////
    CvSVM svm;
    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    svm.train(data, res, Mat(), Mat(), param);
    svm.save( "SVM_DATA.xml" );
}

/// Loads "SVM_DATA.xml" and evaluates it on the test set, normalising each
/// feature vector exactly as newSvmStudy() does.
/// @return 0 on success, -1 when the model or the data cannot be loaded.
int newSvmPredict()
{
    CvSVM svm;
    svm.load( "SVM_DATA.xml" );
    // load() reports nothing, so look at what it produced.
    if (svm.get_support_vector_count() <= 0)
    {
        cerr << "newSvmPredict: cannot load SVM_DATA.xml" << endl;
        return -1;
    }

    return evaluateOnTestSet(svm, true);
}

/// Trains (ON_STUDY = 1) or evaluates (ON_STUDY = 0) the SVM classifier.
/// @return 0 on success, 1 when the required files could not be loaded.
int main( int argc, char *argv[] )
{
    (void)argc;
    (void)argv;

#if(ON_STUDY)
    int maxCount = 60000;
    if (ReadTrainData(maxCount) != 0)
    {
        cerr << "Cannot read training data" << endl;
        return 1;
    }

    //newRtStudy(buffer);
    newSvmStudy(buffer);
    return 0;
#else
    //newRtPredict();
    return (newSvmPredict() == 0) ? 0 : 1;
#endif
}