基于OpenCV库SVM模块的MNIST手写数字识别
来源:互联网 发布:linux启动停在进度条 编辑:程序博客网 时间:2024/04/29 21:31
基于OpenCV库中SVM模块对MNIST手写数字数据库进行识别的代码。
MNIST手写数字数据库包含60000个训练样本和10000个测试样本,是更大的NIST数据集的一个子集。其中的数字已经过尺寸归一化(size-normalized)并居中,均为固定大小(28×28)的图像。
官方地址:http://yann.lecun.com/exdb/mnist/
该网站对数据集有详细介绍。需要注意的是,数据以二进制文件格式存储,而不是图片格式,因此读取时必须按照下面的存放格式解析,再在OpenCV中转换为图像矩阵。另外,文件头中的32位整数都是大端序(MSB first),在x86等小端机器上需要逐字节重新组合,并且每个字段正好占4个字节。
数据格式:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
SVM的识别错误率:
界面:
环境:OpenCV 2.4 + Ubuntu Linux(界面基于Qt)
nistlabledata.h
#ifndef NISTLABLEDATA_H#define NISTLABLEDATA_H#include <opencv2/opencv.hpp>#include "nisttraindata.h"#include "trainsformdata.h"using namespace std;using namespace cv;class NISTLableData:public trainsformdata{public: NISTLableData(); ~NISTLableData();private: long int magic_number; long int number_of_items; static const long int magic_number_setted= 0x801; //friend long int NISTTrainData::trainsform_32bitDataform(long int &data,unsigned char* char_nums);public: unsigned char magic_numbers[4],number_items[4]; long int getnumber_of_items(); bool check_magic_number(); unsigned char lable; void trainsform_Dataforms() { trainsform_32bitDataform(magic_number,magic_numbers); trainsform_32bitDataform(number_of_items,number_items); } void show_Data() { cout<<"magic_number:"<<magic_number<<endl; cout<<"number_of_items:"<<number_of_items<<endl; }};#endif // NISTLABLEDATA_H
Nistlabledata.cpp
#include "nistlabledata.h"NISTLableData::NISTLableData(){ magic_number=0; number_of_items=0;}NISTLableData::~NISTLableData(){}
nisttraindata.h
#ifndef NISTTRAINDATA_H#define NISTTRAINDATA_H#include <opencv2/opencv.hpp>#include "trainsformdata.h"using namespace std;using namespace cv;class NISTTrainData:public trainsformdata{public: NISTTrainData(); ~NISTTrainData();private: long int magic_number; long int number_of_images; long int number_of_rows; long int number_of_columns; static const long int magic_number_setted= 0x803;public: static const int image_row= 20; static const int image_col= 20; unsigned char magicNum[4], ccount[4], crows[4], ccols[4]; void GetROI(Mat& src, Mat& dst); friend long int trainsform_32bitDataform(long int &data,unsigned char* char_nums); long int getnumber_of_images(); long int getrows(); long int getcols(); void trainsform_Dataforms(); void show_Data(); bool check_magic_number(); uchar data[64];};#endif // NISTTRAINDATA_H
Nisttraindata.cpp
#include <algorithm>

#include "nisttraindata.h"
#include "trainsformdata.h"

// Header fields start at zero; trainsform_Dataforms() fills them from the raw bytes.
NISTTrainData::NISTTrainData()
{
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;
}

NISTTrainData::~NISTTrainData()
{
}

/// Crops the digit in an 8-bit single-channel image to a square around its
/// bounding box and rescales it to image_row x image_col.
///
/// @param src  input image, read as CV_8UC1; any pixel > 0 counts as foreground
/// @param dst  output image of image_row rows x image_col columns, CV_8UC1
///
/// The shorter side of the bounding box is widened around its centre so the
/// crop is square. Parts of that square outside src are filled with zeros.
/// A blank src (no foreground pixel) yields an all-zero dst.
void NISTTrainData::GetROI(Mat& src, Mat& dst)
{
    // Bounding box of all non-zero pixels; right/bottom stay -1 if there are none.
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                left = std::min(left, j);
                right = std::max(right, j);
                top = std::min(top, i);
                bottom = std::max(bottom, i);
            }
        }
    }

    if (right < 0)
    {
        // Blank image: an empty bounding box would give a negative crop size.
        dst = Mat::zeros(image_row, image_col, CV_8UC1);
        return;
    }

    Point center((left + right) / 2, (top + bottom) / 2);
    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = std::max(width, height);
    if (width < height)
        left = static_cast<int>(center.x - height * 0.5);
    else if (width > height)
        top = static_cast<int>(center.y - width * 0.5);

    // A square around a digit near the border can extend past src, so copy
    // with bounds checks and leave out-of-image pixels at zero.
    Mat square = Mat::zeros(len, len, CV_8UC1);
    for (int i = 0; i < len; i++)
    {
        const int sy = i + top;
        if (sy < 0 || sy >= src.rows)
            continue;
        const uchar* srcRow = src.ptr<uchar>(sy);
        uchar* dstRow = square.ptr<uchar>(i);
        for (int j = 0; j < len; j++)
        {
            const int sx = j + left;
            if (sx >= 0 && sx < src.cols)
                dstRow[j] = srcRow[sx];
        }
    }

    // cv::Size is (width, height), i.e. (columns, rows).
    resize(square, dst, Size(image_col, image_row));
}

long int NISTTrainData::getnumber_of_images()
{
    return number_of_images;
}

long int NISTTrainData::getrows()
{
    return number_of_rows;
}

long int NISTTrainData::getcols()
{
    return number_of_columns;
}

// Decodes the four raw header fields. The fields are reset first so the result
// does not depend on earlier calls (trainsform_32bitDataform may add into its output).
void NISTTrainData::trainsform_Dataforms()
{
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;
    trainsform_32bitDataform(magic_number, magicNum);
    trainsform_32bitDataform(number_of_images, ccount);
    trainsform_32bitDataform(number_of_rows, crows);
    trainsform_32bitDataform(number_of_columns, ccols);
}

// Prints the decoded header fields to stdout.
void NISTTrainData::show_Data()
{
    cout << " magic_number: " << magic_number
         << " number_of_images: " << number_of_images
         << " number_of_rows: " << number_of_rows
         << " number_of_columns: " << number_of_columns << endl;
}

// True when the decoded magic number identifies an idx3 image file (0x803).
bool NISTTrainData::check_magic_number()
{
    return (magic_number == magic_number_setted);
}
trainsformdata.h
#ifndef TRAINSFORMDATA_H
#define TRAINSFORMDATA_H

/// Base class that decodes the 32-bit big-endian (MSB first) integer fields
/// found in MNIST idx file headers.
class trainsformdata
{
public:
    trainsformdata();
    ~trainsformdata();

    /// Decodes four bytes, most significant first, into @p data.
    ///
    /// @param data      receives the decoded value; it is overwritten, not added to
    /// @param char_nums pointer to at least 4 bytes, most significant byte first
    /// @return the decoded value (equal to the new @p data)
    static long int trainsform_32bitDataform(long int &data, const unsigned char* char_nums)
    {
        data = static_cast<long int>(
            (static_cast<unsigned long int>(char_nums[0]) << 24) |
            (static_cast<unsigned long int>(char_nums[1]) << 16) |
            (static_cast<unsigned long int>(char_nums[2]) << 8) |
            static_cast<unsigned long int>(char_nums[3]));
        return data;
    }
};

#endif // TRAINSFORMDATA_H
trainsformdata.cpp
#include "trainsformdata.h"trainsformdata::trainsformdata(){}trainsformdata::~trainsformdata(){}
Mainwindow.cpp
#include "mainwindow.h"
#include "ui_mainwindow.h"
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/opencv.hpp>
#include "qdebug.h"
#include "nisttraindata.h"
#include "nistlabledata.h"
#include <algorithm>
#include <cfloat>
#include <fstream>
#include <iostream>
#include <vector>

using namespace std;
using namespace cv;

#define NTRAINING_SAMPLES 100 // number of training samples per class (not referenced in this excerpt)
#define FRAC_LINEAR_SEP 0.9f  // fraction of linearly separable samples (not referenced in this excerpt)

/// One preprocessed MNIST sample.
struct InputData
{
    unsigned char lable;  // digit label stored as the character '0'..'9'
    float data[NISTTrainData::image_row * NISTTrainData::image_col];  // ROI pixels (0-255), row-major
} InputData_;  // global kept for compatibility; the functions below use local samples

/// Training samples filled by on_pushButton_2_clicked() and used by on_train_clicked().
vector<InputData> buffer;

// Reads up to maxSamples MNIST (image, label) pairs from an idx3 image file and
// an idx1 label file, preprocesses each image with NISTTrainData::GetROI and
// appends the samples to out. Returns false if a file cannot be opened or its
// header is unusable; on a short read it stops early and keeps what it has.
static bool loadMnistSamples(const char* imageFileName, const char* labelFileName,
                             long int maxSamples, vector<InputData>& out)
{
    NISTTrainData TData;
    NISTLableData LData;

    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(imageFileName, ios_base::binary);
    if (ifs.fail())
    {
        cout << "train fail" << endl;
        return false;
    }
    if (lab_ifs.fail())
    {
        cout << "labelFile fail" << endl;
        return false;
    }

    // Every header field is a 4-byte big-endian integer. Read exactly 4 bytes per
    // field: sizeof(long int) is 8 on 64-bit Linux, which would overrun these
    // 4-byte buffers and consume pixel/label bytes as header.
    ifs.read(reinterpret_cast<char*>(TData.magicNum), sizeof(TData.magicNum));
    ifs.read(reinterpret_cast<char*>(TData.ccount), sizeof(TData.ccount));
    ifs.read(reinterpret_cast<char*>(TData.crows), sizeof(TData.crows));
    ifs.read(reinterpret_cast<char*>(TData.ccols), sizeof(TData.ccols));
    lab_ifs.read(reinterpret_cast<char*>(LData.magic_numbers), sizeof(LData.magic_numbers));
    lab_ifs.read(reinterpret_cast<char*>(LData.number_items), sizeof(LData.number_items));
    if (!ifs || !lab_ifs)
    {
        cout << "header read fail" << endl;
        return false;
    }
    TData.trainsform_Dataforms();
    TData.show_Data();
    LData.trainsform_Dataforms();
    LData.show_Data();
    if (!TData.check_magic_number())
    {
        cout << "image file magic number mismatch" << endl;
        return false;
    }

    const long int rows = TData.getrows();
    const long int cols = TData.getcols();
    if (rows <= 0 || cols <= 0)
    {
        cout << "bad image size" << endl;
        return false;
    }

    // Sized from the header so reading one image can never overrun the buffer.
    Mat src(static_cast<int>(rows), static_cast<int>(cols), CV_8UC1);
    Mat roi;
    const long int limit = std::min(maxSamples, TData.getnumber_of_images());
    for (long int n = 0; n < limit; n++)
    {
        // Check each read instead of looping on eof(), which only turns true after a failed read.
        if (!ifs.read(reinterpret_cast<char*>(src.data), rows * cols))
            break;
        if (!lab_ifs.read(reinterpret_cast<char*>(&LData.lable), 1))
            break;
        TData.GetROI(src, roi);

        InputData sample;
        // Labels are kept as the characters '0'..'9'; the SVM is trained and scored on these codes.
        sample.lable = static_cast<unsigned char>(LData.lable + '0');
        cout << "lable:" << sample.lable << endl;
        for (int i = 0; i < NISTTrainData::image_row; i++)
        {
            for (int j = 0; j < NISTTrainData::image_col; j++)
            {
                sample.data[i * NISTTrainData::image_col + j] = roi.at<uchar>(i, j);
            }
        }
        out.push_back(sample);
    }
    return true;
}

// Wraps a sample's pixels as a 1 x N CV_32FC1 row and L2-normalizes it into
// feature. This is the feature vector used for both training and prediction.
static void makeFeatureRow(const InputData& sample, Mat& feature)
{
    const int featureLen = NISTTrainData::image_col * NISTTrainData::image_row;
    Mat raw(1, featureLen, CV_32FC1, const_cast<float*>(sample.data));
    normalize(raw, feature);
}

void MainWindow::on_pushButton_2_clicked() // load training data
{
    const long int total = 2000;  // only the first 2000 training images are used
    vector<InputData> loaded;
    if (!loadMnistSamples("../res/train-images.idx3-ubyte", "../res/train-labels.idx1-ubyte",
                          total, loaded))
        return;
    // Replace rather than append, so pressing the button twice does not duplicate samples.
    buffer.swap(loaded);
    cout << "load trainingdata ok" << endl;
}

void MainWindow::on_train_clicked()
{
    const vector<InputData>& trainData = buffer;
    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
    {
        cout << "no training data, load data first" << endl;
        return;
    }
    const int featureLen = NISTTrainData::image_col * NISTTrainData::image_row;

    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);  // 32-bit signed labels, accessed as int
    Mat feature;
    for (int i = 0; i < sampleCount; i++)
    {
        makeFeatureRow(trainData[i], feature);
        feature.copyTo(data.row(i));
        res.at<int>(i, 0) = trainData[i].lable;
    }

    CvSVM svm;
    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    //CvSVMParams param(CvSVM::C_SVC, CvSVM::LINEAR, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    cout << "training..." << endl << "it takes a long time, please wait!" << endl;
    svm.train(data, res, Mat(), Mat(), param);
    cout << "training finished..." << endl;

    cout << "saving \"SVM_DATA.xml\"..." << endl;
    svm.save("SVM_DATA.xml");
    cout << "saved..." << endl;

    // Sanity check: reload the saved model and classify the first training sample.
    CvSVM svmpredict;
    svmpredict.load("SVM_DATA.xml");
    makeFeatureRow(trainData[0], feature);
    const char ret = static_cast<char>(svmpredict.predict(feature));
    cout << "ret is :" << ret << endl;
    cout << "labble is :" << trainData[0].lable << endl;
}

void MainWindow::on_testPredict_clicked()
{
    vector<InputData> Testbuffer;
    const long int total = 10000;
    if (!loadMnistSamples("../res/t10k-images.idx3-ubyte", "../res/t10k-labels.idx1-ubyte",
                          total, Testbuffer))
        return;
    cout << "load trainingdata ok" << endl;

    const int testCount = static_cast<int>(Testbuffer.size());
    if (testCount == 0)
    {
        cout << "no test data loaded" << endl;
        return;
    }

    CvSVM svmpredict1;
    svmpredict1.load("SVM_DATA.xml");
    cout << "testing..." << endl;

    int count_test = 0;
    // Progress roughly every 1%; max() keeps the step non-zero for fewer than 100 samples.
    const int step = std::max(1, testCount / 100);
    Mat feature;
    for (int i = 0; i < testCount; i++)
    {
        makeFeatureRow(Testbuffer[i], feature);
        const char ret = static_cast<char>(svmpredict1.predict(feature));
        if (ret == Testbuffer[i].lable)
        {
            count_test++;
        }
        if (i % step == 0)
        {
            cout << (i * 100 / testCount) << "%" << endl;
        }
    }
    cout << "test finished!" << endl;
    cout << "crect:" << (count_test * 1.0 / testCount) * 100 << "%" << endl;
    cout << "totall:" << count_test << endl;
}
其中一个数据(已经被博主归一化了大小):
导入数据的输出提示:
模型训练提示输出:
测试集测试结果:使用线性核时正确率为92.83%(对应代码中被注释掉的CvSVM::LINEAR参数行,当前代码启用的是RBF核),低于上述官网列出的SVM正确率。可能的原因是参数没有调好,另外训练时只载入了前2000个训练样本(total = 2000),而不是全部60000个。
- 基于Opencv库中SVM模块的MNIST手写字识别数据库识别
- 基于tensorflow的MNIST手写字识别
- MNIST手写字识别的TensorFlow实现
- Tensorflow | MNIST手写字识别
- 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型
- 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型
- MXNet | 手写字MNIST识别比赛
- tensorflow mnist数据集手写字识别
- 我的第一个svm程序:手写字识别
- TensorFlow学习笔记(一)MNIST手写字识别
- caffe学习例子(一) mnist手写字识别
- 手写字识别C++
- OpenCV的svm手写字检测
- Python神经网络代码识别手写字的实现流程(一):加载mnist数据
- opencv 基于SVM的几何形状识别
- K-近邻:手写字识别
- 机器学习笔记2-基于KNN算法的手写字识别程序
- 基于 OpenCV 的 LBP + SVM 人脸识别
- github集合
- centos简介与VMware安装
- Webstorm 10 for mac osx 注册机,序列号,kegen
- 接口测试集合
- 三色球
- 基于Opencv库中SVM模块的MNIST手写字识别数据库识别
- Evbuffers:IO缓冲的实用功能
- UIViewAnimation 动画
- pat1012(逆向计数排序的应用)
- js数组常用的方法
- 转发和重定向 方法的使用
- 黑马程序员-OC回顾-面向对象
- servlet与tomcat
- 如果你身边有程序员,请善待他们。