OpenCV的SVM手写数字检测

来源:互联网 发布:js resize事件 编辑:程序博客网 时间:2024/04/28 14:04

转自http://blog.csdn.net/firefight/article/details/6452188

数据来源为MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

其他方法:http://blog.csdn.net/onezeros/article/details/5672192

 

 

使用OPENCV训练手写数字识别分类器 

1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征:采用最简单的方式,先裁剪出字符所在区域并居中放入正方形,再缩放到8×8网格,以这64个像素值作为特征


4,创建SVM,训练并保存模型,再读取模型对测试数据进行识别,结果如下
 1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
  10000个训练样本,测试数据正确率95.45%
  60000个训练样本,测试数据正确率97.67%

5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

 

以下为主要代码,以供参考

(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

#include "stdafx.h"
#include <cfloat>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

// 1: display each digit (enlarged, with its label) and wait for a key press; ESC stops early.
#define SHOW_PROCESS 0
// 1: read the MNIST training set, train a model and save it.
// 0: load the saved model and evaluate it on the MNIST test set.
#define ON_STUDY 0

// One sample: 64 features (an 8x8 grid, row-major) and its label.
class NumTrainData
{
public:
    NumTrainData()
    {
        memset(data, 0, sizeof(data));
        result = -1;
    }
public:
    float data[64];
    int result;  // label stored as a character code, '0'..'9'
};

// Samples filled in by ReadTrainData().
vector<NumTrainData> buffer;
// Features per sample; must not exceed the size of NumTrainData::data.
int featureLen = 64;

// Reverses the byte order of a 4-byte buffer in place.
void swapBuffer(char* buf)
{
    char temp = buf[0];
    buf[0] = buf[3];
    buf[3] = temp;
    temp = buf[1];
    buf[1] = buf[2];
    buf[2] = temp;
}

// Crops the bounding box of the non-zero pixels of `src` (CV_8UC1) and centres
// it in a square, zero-filled CV_8UC1 image written to `dst`.
// If `src` has no non-zero pixel, `dst` becomes an all-zero image of src's size.
// NOTE: the crop now includes the last foreground row/column (the original code
// dropped them), so models trained with the old feature extraction should be retrained.
void GetROI(Mat& src, Mat& dst)
{
    int left = src.cols, right = -1, top = src.rows, bottom = -1;

    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    if (right < left || bottom < top)
    {
        // Blank image: there is nothing to crop (an empty Rect would assert).
        dst = Mat::zeros(src.size(), CV_8UC1);
        return;
    }

    // right/bottom are inclusive indices, hence the +1.
    int width = right - left + 1;
    int height = bottom - top + 1;
    int len = (width < height) ? height : width;

    // Create a square and copy the valid area to its centre.
    dst = Mat::zeros(len, len, CV_8UC1);
    Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
    Rect srcRect(left, top, width, height);
    Mat dstROI = dst(dstRect);
    src(srcRect).copyTo(dstROI);
}

// Reads one 4-byte big-endian integer (an MNIST IDX header field).
// Decodes byte by byte, so it does not depend on the host byte order.
static bool ReadBigEndianInt(ifstream& ifs, int& value)
{
    unsigned char buf[4];
    if (!ifs.read(reinterpret_cast<char*>(buf), sizeof(buf)))
        return false;
    value = static_cast<int>((static_cast<unsigned>(buf[0]) << 24) |
                             (static_cast<unsigned>(buf[1]) << 16) |
                             (static_cast<unsigned>(buf[2]) << 8) |
                             static_cast<unsigned>(buf[3]));
    return true;
}

// Opens an MNIST image file and its label file and validates both headers.
// On success, `count` is the number of samples present in BOTH files,
// `rows`/`cols` the image size, and both streams are positioned at the first sample.
static bool OpenMnist(const char* imageFile, const char* labelFile,
                      ifstream& ifs, ifstream& lab_ifs,
                      int& count, int& rows, int& cols)
{
    ifs.open(imageFile, ios_base::binary);
    lab_ifs.open(labelFile, ios_base::binary);
    if (ifs.fail() || lab_ifs.fail())
        return false;

    // Image header: magic, count, rows, cols. Label header: magic, count.
    int imageMagic = 0, labelMagic = 0, labelCount = 0;
    if (!ReadBigEndianInt(ifs, imageMagic) || !ReadBigEndianInt(ifs, count) ||
        !ReadBigEndianInt(ifs, rows) || !ReadBigEndianInt(ifs, cols) ||
        !ReadBigEndianInt(lab_ifs, labelMagic) || !ReadBigEndianInt(lab_ifs, labelCount))
        return false;

    // IDX magic numbers: 0x00000803 = unsigned-byte 3-D data (images),
    // 0x00000801 = unsigned-byte 1-D data (labels).
    if (imageMagic != 0x00000803 || labelMagic != 0x00000801 || rows <= 0 || cols <= 0)
        return false;

    // Never read past the shorter of the two files.
    if (labelCount < count)
        count = labelCount;
    return count >= 0;
}

// Crops the digit in `src` (the crop is left in `dst`, e.g. for display),
// shrinks it to 8x8 and stores the 64 pixel values row-major in `feature`.
static void ExtractFeature(Mat& src, Mat& dst, float* feature)
{
    Mat grid(8, 8, CV_8UC1);
    GetROI(src, dst);
    resize(dst, grid, grid.size());
    for (int i = 0; i < 8; i++)
        for (int j = 0; j < 8; j++)
            feature[i * 8 + j] = grid.at<uchar>(i, j);
}

// Loads up to `maxCount` samples of the MNIST training set into `buffer`
// (maxCount <= 0 loads all of them).
// Returns 0 on success, -1 if the files cannot be opened or their headers are invalid.
int ReadTrainData(int maxCount)
{
    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";

    ifstream ifs, lab_ifs;
    int count = 0, rows = 0, cols = 0;
    if (!OpenMnist(fileName, labelFileName, ifs, lab_ifs, count, rows, cols))
        return -1;

    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat dst;
    NumTrainData rtd;

    for (int total = 0; total < count; total++)
    {
        if (maxCount > 0 && total >= maxCount)
            break;

        char label = 0;
        if (!lab_ifs.read(&label, 1) ||
            !ifs.read(reinterpret_cast<char*>(src.data), rows * cols))
            break;  // truncated file: keep what was read so far
        label = label + '0';

        ExtractFeature(src, dst, rtd.data);
        rtd.result = label;
        buffer.push_back(rtd);

        // Progress report without flushing on every sample.
        if ((total + 1) % 1000 == 0)
            cout << (total + 1) << '\n';

#if SHOW_PROCESS
        // The 28x28 crop is too small to watch, so enlarge it.
        Mat img;
        resize(dst, img, Size(dst.cols * 10, dst.rows * 10));
        stringstream ss;
        ss << "Number " << label;
        putText(img, ss.str(), Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(255, 0, 255));
        imshow("img", img);
        if (waitKey(0) == 27)  // ESC to quit
            break;
#endif
    }
    return 0;
}

// Packs samples into an (N x featureLen) CV_32FC1 matrix and an (N x 1) CV_32SC1
// label column. When `normalizeRows` is true every row is L2-normalized.
static void BuildTrainMatrices(const vector<NumTrainData>& trainData, bool normalizeRows,
                               Mat& data, Mat& res)
{
    int n = static_cast<int>(trainData.size());
    data = Mat::zeros(n, featureLen, CV_32FC1);
    res = Mat::zeros(n, 1, CV_32SC1);
    for (int i = 0; i < n; i++)
    {
        const NumTrainData& td = trainData[i];
        Mat row = data.row(i);  // header onto data's storage: writes go into `data`
        Mat(1, featureLen, CV_32FC1, const_cast<float*>(td.data)).copyTo(row);
        if (normalizeRows)
            normalize(row, row);
        res.at<int>(i, 0) = td.result;  // CV_32SC1 -> int, not unsigned int
    }
}

// Trains a random-trees classifier on `trainData` and saves it to "new_rtrees.xml".
void newRtStudy(vector<NumTrainData>& trainData)
{
    if (trainData.empty())
        return;

    Mat data, res;
    BuildTrainMatrices(trainData, false, data, res);

    // All features are ordered; the response (last entry) is categorical, so the
    // forest is trained as a classifier rather than a regressor.
    Mat varType(featureLen + 1, 1, CV_8U, Scalar::all(CV_VAR_ORDERED));
    varType.at<uchar>(featureLen, 0) = CV_VAR_CATEGORICAL;

    CvRTrees forest;
    forest.train(data, CV_ROW_SAMPLE, res, Mat(), Mat(), varType, Mat(),
                 CvRTParams(10, 10, 0, false, 15, 0, true, 4, 100, 0.01f, CV_TERMCRIT_ITER));
    forest.save("new_rtrees.xml");
}

// Trains an RBF C-SVC on L2-normalized `trainData` and saves it to "SVM_DATA.xml".
void newSvmStudy(vector<NumTrainData>& trainData)
{
    if (trainData.empty())
        return;

    Mat data, res;
    BuildTrainMatrices(trainData, true, data, res);

    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, 0, criteria);
    CvSVM svm;
    svm.train(data, res, Mat(), Mat(), param);
    svm.save("SVM_DATA.xml");
}

// Runs `model` over the MNIST test set, prints the tally and shows it on the last
// digit. `normalizeFeature` must match how the model was trained.
// Returns 0 on success, -1 if the test files cannot be read.
template <typename Model>
static int EvaluateOnTestSet(Model& model, bool normalizeFeature)
{
    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

    ifstream ifs, lab_ifs;
    int count = 0, rows = 0, cols = 0;
    if (!OpenMnist(fileName, labelFileName, ifs, lab_ifs, count, rows, cols))
        return -1;

    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat img, dst;
    Scalar templateColor(255, 0, 0);
    int right = 0, error = 0, total = 0;

    // Stop after exactly `count` samples instead of reading one record past the end.
    while (total < count)
    {
        char label = 0;
        if (!lab_ifs.read(&label, 1) ||
            !ifs.read(reinterpret_cast<char*>(src.data), rows * cols))
            break;
        label = label + '0';

        ExtractFeature(src, dst, m.ptr<float>(0));
        if (normalizeFeature)
            normalize(m, m);

        // Round rather than truncate the float class label returned by predict().
        int ret = cvRound(model.predict(m));
        if (ret == label)
            right++;
        else
            error++;
        total++;

#if SHOW_PROCESS
        // The 28x28 crop is too small to watch, so enlarge it.
        resize(dst, img, Size(dst.cols * 30, dst.rows * 30));
        stringstream ss;
        ss << "Number " << label << ", predict " << static_cast<char>(ret);
        putText(img, ss.str(), Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
        imshow("img", img);
        if (waitKey(0) == 27)  // ESC to quit
            break;
#endif
    }

    stringstream ss;
    ss << "Total " << total << ", right " << right << ", error " << error;
    string text = ss.str();
    cout << text << endl;

    // Show the summary on the last digit, or on a blank canvas if none was read.
    if (dst.empty())
        img = Mat::zeros(300, 840, CV_8UC1);
    else
        resize(dst, img, Size(dst.cols * 30, dst.rows * 30));
    putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    imshow("img", img);
    waitKey(0);
    return 0;
}

// Loads "new_rtrees.xml" and evaluates it on the MNIST test set (raw features).
int newRtPredict()
{
    CvRTrees forest;
    forest.load("new_rtrees.xml");
    return EvaluateOnTestSet(forest, false);
}

// Loads "SVM_DATA.xml" and evaluates it on the MNIST test set (L2-normalized features).
int newSvmPredict()
{
    CvSVM svm;
    svm.load("SVM_DATA.xml");
    return EvaluateOnTestSet(svm, true);
}

int main(int argc, char* argv[])
{
    int status = 0;
#if ON_STUDY
    int maxCount = 60000;
    status = ReadTrainData(maxCount);
    if (status == 0)
    {
        //newRtStudy(buffer);
        newSvmStudy(buffer);
    }
#else
    //status = newRtPredict();
    status = newSvmPredict();
#endif
    if (status != 0)
        cerr << "Cannot read the MNIST data files under ../res/" << endl;
    return status == 0 ? 0 : 1;
}


原创粉丝点击