使用OPENCV训练手写数字识别分类器

来源:互联网 发布:xbrower连接linux桌面 编辑:程序博客网 时间:2024/05/18 03:57

【原文:http://blog.csdn.net/firefight/article/details/6452188】

使用OpenCV训练手写数字识别分类器 

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数


4,创建SVM,训练并读取,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

 

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

 

[cpp] view plaincopy
  1. #include "stdafx.h"  
  2.   
  3. #include <fstream>  
  4. #include "opencv2/opencv.hpp"  
  5. #include <vector>  
  6.   
  7. using namespace std;  
  8. using namespace cv;  
  9.   
  10. #define SHOW_PROCESS 0  
  11. #define ON_STUDY 0  
  12.   
  13. class NumTrainData  
  14. {  
  15. public:  
  16.     NumTrainData()  
  17.     {  
  18.         memset(data, 0, sizeof(data));  
  19.         result = -1;  
  20.     }  
  21. public:  
  22.     float data[64];  
  23.     int result;  
  24. };  
  25.   
  26. vector<NumTrainData> buffer;  
  27. int featureLen = 64;  
  28.   
  29. void swapBuffer(char* buf)  
  30. {  
  31.     char temp;  
  32.     temp = *(buf);  
  33.     *buf = *(buf+3);  
  34.     *(buf+3) = temp;  
  35.   
  36.     temp = *(buf+1);  
  37.     *(buf+1) = *(buf+2);  
  38.     *(buf+2) = temp;  
  39. }  
  40.   
  41. void GetROI(Mat& src, Mat& dst)  
  42. {  
  43.     int left, right, top, bottom;  
  44.     left = src.cols;  
  45.     right = 0;  
  46.     top = src.rows;  
  47.     bottom = 0;  
  48.   
  49.     //Get valid area  
  50.     for(int i=0; i<src.rows; i++)  
  51.     {  
  52.         for(int j=0; j<src.cols; j++)  
  53.         {  
  54.             if(src.at<uchar>(i, j) > 0)  
  55.             {  
  56.                 if(j<left) left = j;  
  57.                 if(j>right) right = j;  
  58.                 if(i<top) top = i;  
  59.                 if(i>bottom) bottom = i;  
  60.             }  
  61.         }  
  62.     }  
  63.   
  64.     //Point center;  
  65.     //center.x = (left + right) / 2;  
  66.     //center.y = (top + bottom) / 2;  
  67.   
  68.     int width = right - left;  
  69.     int height = bottom - top;  
  70.     int len = (width < height) ? height : width;  
  71.   
  72.     //Create a squre  
  73.     dst = Mat::zeros(len, len, CV_8UC1);  
  74.   
  75.     //Copy valid data to squre center  
  76.     Rect dstRect((len - width)/2, (len - height)/2, width, height);  
  77.     Rect srcRect(left, top, width, height);  
  78.     Mat dstROI = dst(dstRect);  
  79.     Mat srcROI = src(srcRect);  
  80.     srcROI.copyTo(dstROI);  
  81. }  
  82.   
  83. int ReadTrainData(int maxCount)  
  84. {  
  85.     //Open image and label file  
  86.     const char fileName[] = "../res/train-images.idx3-ubyte";  
  87.     const char labelFileName[] = "../res/train-labels.idx1-ubyte";  
  88.   
  89.     ifstream lab_ifs(labelFileName, ios_base::binary);  
  90.     ifstream ifs(fileName, ios_base::binary);  
  91.   
  92.     if( ifs.fail() == true )  
  93.         return -1;  
  94.   
  95.     if( lab_ifs.fail() == true )  
  96.         return -1;  
  97.   
  98.     //Read train data number and image rows / cols  
  99.     char magicNum[4], ccount[4], crows[4], ccols[4];  
  100.     ifs.read(magicNum, sizeof(magicNum));  
  101.     ifs.read(ccount, sizeof(ccount));  
  102.     ifs.read(crows, sizeof(crows));  
  103.     ifs.read(ccols, sizeof(ccols));  
  104.   
  105.     int count, rows, cols;  
  106.     swapBuffer(ccount);  
  107.     swapBuffer(crows);  
  108.     swapBuffer(ccols);  
  109.   
  110.     memcpy(&count, ccount, sizeof(count));  
  111.     memcpy(&rows, crows, sizeof(rows));  
  112.     memcpy(&cols, ccols, sizeof(cols));  
  113.   
  114.     //Just skip label header  
  115.     lab_ifs.read(magicNum, sizeof(magicNum));  
  116.     lab_ifs.read(ccount, sizeof(ccount));  
  117.   
  118.     //Create source and show image matrix  
  119.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
  120.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
  121.     Mat img, dst;  
  122.   
  123.     char label = 0;  
  124.     Scalar templateColor(255, 0, 255 );  
  125.   
  126.     NumTrainData rtd;  
  127.   
  128.     //int loop = 1000;  
  129.     int total = 0;  
  130.   
  131.     while(!ifs.eof())  
  132.     {  
  133.         if(total >= count)  
  134.             break;  
  135.           
  136.         total++;  
  137.         cout << total << endl;  
  138.           
  139.         //Read label  
  140.         lab_ifs.read(&label, 1);  
  141.         label = label + '0';  
  142.   
  143.         //Read source data  
  144.         ifs.read((char*)src.data, rows * cols);  
  145.         GetROI(src, dst);  
  146.   
  147. #if(SHOW_PROCESS)  
  148.         //Too small to watch  
  149.         img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);  
  150.         resize(dst, img, img.size());  
  151.   
  152.         stringstream ss;  
  153.         ss << "Number " << label;  
  154.         string text = ss.str();  
  155.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
  156.   
  157.         //imshow("img", img);  
  158. #endif  
  159.   
  160.         rtd.result = label;  
  161.         resize(dst, temp, temp.size());  
  162.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);  
  163.   
  164.         for(int i = 0; i<8; i++)  
  165.         {  
  166.             for(int j = 0; j<8; j++)  
  167.             {  
  168.                     rtd.data[ i*8 + j] = temp.at<uchar>(i, j);  
  169.             }  
  170.         }  
  171.   
  172.         buffer.push_back(rtd);  
  173.   
  174.         //if(waitKey(0)==27) //ESC to quit  
  175.         //  break;  
  176.   
  177.         maxCount--;  
  178.           
  179.         if(maxCount == 0)  
  180.             break;  
  181.     }  
  182.   
  183.     ifs.close();  
  184.     lab_ifs.close();  
  185.   
  186.     return 0;  
  187. }  
  188.   
  189. void newRtStudy(vector<NumTrainData>& trainData)  
  190. {  
  191.     int testCount = trainData.size();  
  192.   
  193.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
  194.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
  195.   
  196.     for (int i= 0; i< testCount; i++)   
  197.     {   
  198.   
  199.         NumTrainData td = trainData.at(i);  
  200.         memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));  
  201.   
  202.         res.at<unsigned int>(i, 0) = td.result;  
  203.     }  
  204.   
  205.     /////////////START RT TRAINNING//////////////////  
  206.     CvRTrees forest;  
  207.     CvMat* var_importance = 0;  
  208.   
  209.     forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),  
  210.             CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));  
  211.     forest.save( "new_rtrees.xml" );  
  212. }  
  213.   
  214.   
  215. int newRtPredict()  
  216. {  
  217.     CvRTrees forest;  
  218.     forest.load( "new_rtrees.xml" );  
  219.   
  220.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
  221.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
  222.   
  223.     ifstream lab_ifs(labelFileName, ios_base::binary);  
  224.     ifstream ifs(fileName, ios_base::binary);  
  225.   
  226.     if( ifs.fail() == true )  
  227.         return -1;  
  228.   
  229.     if( lab_ifs.fail() == true )  
  230.         return -1;  
  231.   
  232.     char magicNum[4], ccount[4], crows[4], ccols[4];  
  233.     ifs.read(magicNum, sizeof(magicNum));  
  234.     ifs.read(ccount, sizeof(ccount));  
  235.     ifs.read(crows, sizeof(crows));  
  236.     ifs.read(ccols, sizeof(ccols));  
  237.   
  238.     int count, rows, cols;  
  239.     swapBuffer(ccount);  
  240.     swapBuffer(crows);  
  241.     swapBuffer(ccols);  
  242.   
  243.     memcpy(&count, ccount, sizeof(count));  
  244.     memcpy(&rows, crows, sizeof(rows));  
  245.     memcpy(&cols, ccols, sizeof(cols));  
  246.   
  247.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
  248.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
  249.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
  250.     Mat img, dst;  
  251.   
  252.     //Just skip label header  
  253.     lab_ifs.read(magicNum, sizeof(magicNum));  
  254.     lab_ifs.read(ccount, sizeof(ccount));  
  255.   
  256.     char label = 0;  
  257.     Scalar templateColor(255, 0, 0);  
  258.   
  259.     NumTrainData rtd;  
  260.   
  261.     int right = 0, error = 0, total = 0;  
  262.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
  263.     while(ifs.good())  
  264.     {  
  265.         //Read label  
  266.         lab_ifs.read(&label, 1);  
  267.         label = label + '0';  
  268.   
  269.         //Read data  
  270.         ifs.read((char*)src.data, rows * cols);  
  271.         GetROI(src, dst);  
  272.   
  273.         //Too small to watch  
  274.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
  275.         resize(dst, img, img.size());  
  276.   
  277.         rtd.result = label;  
  278.         resize(dst, temp, temp.size());  
  279.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);  
  280.         for(int i = 0; i<8; i++)  
  281.         {  
  282.             for(int j = 0; j<8; j++)  
  283.             {  
  284.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
  285.             }  
  286.         }  
  287.   
  288.         if(total >= count)  
  289.             break;  
  290.   
  291.         char ret = (char)forest.predict(m);   
  292.   
  293.         if(ret == label)  
  294.         {  
  295.             right++;  
  296.             if(total <= 5000)  
  297.                 right_1++;  
  298.             else  
  299.                 right_2++;  
  300.         }  
  301.         else  
  302.         {  
  303.             error++;  
  304.             if(total <= 5000)  
  305.                 error_1++;  
  306.             else  
  307.                 error_2++;  
  308.         }  
  309.   
  310.         total++;  
  311.   
  312. #if(SHOW_PROCESS)  
  313.         stringstream ss;  
  314.         ss << "Number " << label << ", predict " << ret;  
  315.         string text = ss.str();  
  316.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
  317.   
  318.         imshow("img", img);  
  319.         if(waitKey(0)==27) //ESC to quit  
  320.             break;  
  321. #endif  
  322.   
  323.     }  
  324.   
  325.     ifs.close();  
  326.     lab_ifs.close();  
  327.   
  328.     stringstream ss;  
  329.     ss << "Total " << total << ", right " << right <<", error " << error;  
  330.     string text = ss.str();  
  331.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
  332.     imshow("img", img);  
  333.     waitKey(0);  
  334.   
  335.     return 0;  
  336. }  
  337.   
  338. void newSvmStudy(vector<NumTrainData>& trainData)  
  339. {  
  340.     int testCount = trainData.size();  
  341.   
  342.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
  343.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
  344.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
  345.   
  346.     for (int i= 0; i< testCount; i++)   
  347.     {   
  348.   
  349.         NumTrainData td = trainData.at(i);  
  350.         memcpy(m.data, td.data, featureLen*sizeof(float));  
  351.         normalize(m, m);  
  352.         memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));  
  353.   
  354.         res.at<unsigned int>(i, 0) = td.result;  
  355.     }  
  356.   
  357.     /////////////START SVM TRAINNING//////////////////  
  358.     CvSVM svm = CvSVM();   
  359.     CvSVMParams param;   
  360.     CvTermCriteria criteria;  
  361.   
  362.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);   
  363.     param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
  364.   
  365.     svm.train(data, res, Mat(), Mat(), param);  
  366.     svm.save( "SVM_DATA.xml" );  
  367. }  
  368.   
  369.   
  370. int newSvmPredict()  
  371. {  
  372.     CvSVM svm = CvSVM();   
  373.     svm.load( "SVM_DATA.xml" );  
  374.   
  375.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
  376.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
  377.   
  378.     ifstream lab_ifs(labelFileName, ios_base::binary);  
  379.     ifstream ifs(fileName, ios_base::binary);  
  380.   
  381.     if( ifs.fail() == true )  
  382.         return -1;  
  383.   
  384.     if( lab_ifs.fail() == true )  
  385.         return -1;  
  386.   
  387.     char magicNum[4], ccount[4], crows[4], ccols[4];  
  388.     ifs.read(magicNum, sizeof(magicNum));  
  389.     ifs.read(ccount, sizeof(ccount));  
  390.     ifs.read(crows, sizeof(crows));  
  391.     ifs.read(ccols, sizeof(ccols));  
  392.   
  393.     int count, rows, cols;  
  394.     swapBuffer(ccount);  
  395.     swapBuffer(crows);  
  396.     swapBuffer(ccols);  
  397.   
  398.     memcpy(&count, ccount, sizeof(count));  
  399.     memcpy(&rows, crows, sizeof(rows));  
  400.     memcpy(&cols, ccols, sizeof(cols));  
  401.   
  402.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
  403.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
  404.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
  405.     Mat img, dst;  
  406.   
  407.     //Just skip label header  
  408.     lab_ifs.read(magicNum, sizeof(magicNum));  
  409.     lab_ifs.read(ccount, sizeof(ccount));  
  410.   
  411.     char label = 0;  
  412.     Scalar templateColor(255, 0, 0);  
  413.   
  414.     NumTrainData rtd;  
  415.   
  416.     int right = 0, error = 0, total = 0;  
  417.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
  418.     while(ifs.good())  
  419.     {  
  420.         //Read label  
  421.         lab_ifs.read(&label, 1);  
  422.         label = label + '0';  
  423.   
  424.         //Read data  
  425.         ifs.read((char*)src.data, rows * cols);  
  426.         GetROI(src, dst);  
  427.   
  428.         //Too small to watch  
  429.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
  430.         resize(dst, img, img.size());  
  431.   
  432.         rtd.result = label;  
  433.         resize(dst, temp, temp.size());  
  434.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);  
  435.         for(int i = 0; i<8; i++)  
  436.         {  
  437.             for(int j = 0; j<8; j++)  
  438.             {  
  439.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
  440.             }  
  441.         }  
  442.   
  443.         if(total >= count)  
  444.             break;  
  445.   
  446.         normalize(m, m);  
  447.         char ret = (char)svm.predict(m);   
  448.   
  449.         if(ret == label)  
  450.         {  
  451.             right++;  
  452.             if(total <= 5000)  
  453.                 right_1++;  
  454.             else  
  455.                 right_2++;  
  456.         }  
  457.         else  
  458.         {  
  459.             error++;  
  460.             if(total <= 5000)  
  461.                 error_1++;  
  462.             else  
  463.                 error_2++;  
  464.         }  
  465.   
  466.         total++;  
  467.   
  468. #if(SHOW_PROCESS)  
  469.         stringstream ss;  
  470.         ss << "Number " << label << ", predict " << ret;  
  471.         string text = ss.str();  
  472.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
  473.   
  474.         imshow("img", img);  
  475.         if(waitKey(0)==27) //ESC to quit  
  476.             break;  
  477. #endif  
  478.   
  479.     }  
  480.   
  481.     ifs.close();  
  482.     lab_ifs.close();  
  483.   
  484.     stringstream ss;  
  485.     ss << "Total " << total << ", right " << right <<", error " << error;  
  486.     string text = ss.str();  
  487.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
  488.     imshow("img", img);  
  489.     waitKey(0);  
  490.   
  491.     return 0;  
  492. }  
  493.   
  494. int main( int argc, char *argv[] )  
  495. {  
  496. #if(ON_STUDY)  
  497.     int maxCount = 60000;  
  498.     ReadTrainData(maxCount);  
  499.   
  500.     //newRtStudy(buffer);  
  501.     newSvmStudy(buffer);  
  502. #else  
  503.     //newRtPredict();  
  504.     newSvmPredict();  
  505. #endif  
  506.     return 0;  
  507. }  

 





原文地址:http://blog.csdn.net/zhazhiqiang/article/details/20137087

0 0