基于OpenCV库SVM模块的MNIST手写数字数据库识别

来源:互联网 发布:linux启动停在进度条 编辑:程序博客网 时间:2024/04/29 21:31

基于OpenCV库SVM模块的MNIST手写数字数据库识别代码。

 

       MNIST手写数字数据库包含60000个训练样本和10000个测试样本,它是更大的数据集NIST的一个子集。数字已经过尺寸归一化(size-normalized),图像大小固定。

       官方地址:http://yann.lecun.com/exdb/mnist/

       上面的网址对这个数据集有详细介绍。这里提一下,数据是以二进制文件格式存储的,而不是图片格式,所以需要注意其数据存放格式,并在OpenCV中进行数据格式转换。

数据格式:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset]  [type]          [value]            [description]

0000      32 bit integer  0x00000801(2049)   magic number (MSB first)

0004      32 bit integer  60000              number of items

0008      unsigned byte   ??                 label

........

xxxx      unsigned byte   ??                 label

The label values are 0 to 9.

 

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset]  [type]          [value]            [description]

0000      32 bit integer  0x00000803(2051)   magic number

0004      32 bit integer  60000              number of images

0008      32 bit integer  28                 number of rows

0012      32 bit integer  28                 number of columns

0016      unsigned byte   ??                 pixel

0017      unsigned byte   ??                 pixel

........

xxxx      unsigned byte   ??                 pixel



SVM的识别错误率:


界面:


环境:opencv2.4+Ubuntu+linux


nistlabledata.h

#ifndef NISTLABLEDATA_H#define NISTLABLEDATA_H#include <opencv2/opencv.hpp>#include "nisttraindata.h"#include "trainsformdata.h"using namespace std;using namespace cv;class NISTLableData:public trainsformdata{public:    NISTLableData();    ~NISTLableData();private:    long int magic_number;    long int number_of_items;    static const long int magic_number_setted= 0x801;    //friend long int NISTTrainData::trainsform_32bitDataform(long int &data,unsigned char* char_nums);public:    unsigned char magic_numbers[4],number_items[4];    long int getnumber_of_items();    bool check_magic_number();    unsigned char lable;    void trainsform_Dataforms()    {         trainsform_32bitDataform(magic_number,magic_numbers);         trainsform_32bitDataform(number_of_items,number_items);    }    void show_Data()    {        cout<<"magic_number:"<<magic_number<<endl;        cout<<"number_of_items:"<<number_of_items<<endl;    }};#endif // NISTLABLEDATA_H

Nistlabledata.cpp

#include "nistlabledata.h"NISTLableData::NISTLableData(){    magic_number=0;    number_of_items=0;}NISTLableData::~NISTLableData(){}

Nisttraindata.h

#ifndef NISTTRAINDATA_H#define NISTTRAINDATA_H#include <opencv2/opencv.hpp>#include "trainsformdata.h"using namespace std;using namespace cv;class NISTTrainData:public trainsformdata{public:    NISTTrainData();    ~NISTTrainData();private:    long int magic_number;    long int number_of_images;    long int number_of_rows;    long int number_of_columns;    static const long int magic_number_setted= 0x803;public:    static const int image_row= 20;    static const int image_col= 20;    unsigned char magicNum[4], ccount[4], crows[4], ccols[4];    void GetROI(Mat& src, Mat& dst);    friend long int trainsform_32bitDataform(long int &data,unsigned char* char_nums);    long int getnumber_of_images();    long int getrows();    long int getcols();    void trainsform_Dataforms();    void show_Data();    bool check_magic_number();    uchar data[64];};#endif // NISTTRAINDATA_H

Nisttraindata.cpp

#include "nisttraindata.h"
#include "trainsformdata.h"
//#include "trainsformdata.h"

// All decoded header fields start at zero.
NISTTrainData::NISTTrainData()
{
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;
}

NISTTrainData::~NISTTrainData()
{
}

/// Extracts the digit from `src` as a square thumbnail.
///
/// Finds the bounding box of all non-zero pixels, widens its shorter side so
/// the box becomes a square around the box centre, copies that window and
/// resizes it to image_row x image_col.
///
/// @param src single-channel 8-bit image (CV_8UC1); it is only read.
/// @param dst receives the image_row x image_col CV_8UC1 thumbnail. A source
///            without any non-zero pixel yields an all-zero thumbnail.
void NISTTrainData::GetROI(Mat& src, Mat& dst)
{
    CV_Assert(src.type() == CV_8UC1);

    // Bounding box of the foreground (non-zero) pixels.
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;
    for(int i = 0; i < src.rows; i++)
    {
        const uchar* line = src.ptr<uchar>(i);
        for(int j = 0; j < src.cols; j++)
        {
            if(line[j] > 0)
            {
                if(j < left) left = j;
                if(j > right) right = j;
                if(i < top) top = i;
                if(i > bottom) bottom = i;
            }
        }
    }

    // Blank image: there is no box to crop. Without this guard the box
    // dimensions below would be negative.
    if(right < 0)
    {
        dst = Mat::zeros(image_row, image_col, CV_8UC1);
        return;
    }

    const int centerX = (left + right) / 2;
    const int centerY = (top + bottom) / 2;
    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = (width < height) ? height : width;

    // Grow the shorter side so the window is len x len around the centre.
    if(width < height)
    {
        left = static_cast<int>(centerX - height * 0.5);
    }
    else if(width > height)
    {
        top = static_cast<int>(centerY - width * 0.5);
    }

    // The squared window may reach past the image border when the digit sits
    // near an edge; pixels outside the source stay zero instead of being read
    // from unrelated memory. Row pointers are used so a non-continuous `src`
    // is handled as well.
    Mat square = Mat::zeros(len, len, CV_8UC1);
    for(int i = 0; i < len; i++)
    {
        const int srcRow = i + top;
        if(srcRow < 0 || srcRow >= src.rows)
            continue;
        const uchar* in = src.ptr<uchar>(srcRow);
        uchar* out = square.ptr<uchar>(i);
        for(int j = 0; j < len; j++)
        {
            const int srcCol = j + left;
            if(srcCol >= 0 && srcCol < src.cols)
                out[j] = in[srcCol];
        }
    }

    // cv::Size takes (width, height), i.e. (columns, rows).
    resize(square, dst, Size(image_col, image_row));
}

long int NISTTrainData::getnumber_of_images()
{
    return number_of_images;
}

long int NISTTrainData::getrows()
{
    return number_of_rows;
}

long int NISTTrainData::getcols()
{
    return number_of_columns;
}

/// Decodes the raw header bytes (magicNum, ccount, crows, ccols) into the
/// numeric header fields.
void NISTTrainData::trainsform_Dataforms()
{
    // Start from zero so decoding a second header never adds onto the values
    // left by a previous call.
    magic_number = 0;
    number_of_images = 0;
    number_of_rows = 0;
    number_of_columns = 0;

    trainsform_32bitDataform(magic_number, magicNum);
    trainsform_32bitDataform(number_of_images, ccount);
    trainsform_32bitDataform(number_of_rows, crows);
    trainsform_32bitDataform(number_of_columns, ccols);
}

/// Prints the decoded header fields on one line of stdout.
void NISTTrainData::show_Data()
{
    cout<<" magic_number: "<<magic_number<<
          " number_of_images: "<<number_of_images<<
          " number_of_rows: "<<number_of_rows<<
          " number_of_columns: "<<number_of_columns<<endl;
}

/// @return true when the decoded magic number equals magic_number_setted.
bool NISTTrainData::check_magic_number()
{
    return (magic_number==magic_number_setted);
}

Trainsformdata.h

#ifndef TRAINSFORMDATA_H
#define TRAINSFORMDATA_H

/// Base class providing the big-endian header decoding shared by the MNIST
/// header classes.
class trainsformdata
{
public:
    trainsformdata();
    ~trainsformdata();

    /// Decodes four bytes stored most-significant byte first into an integer.
    ///
    /// @param data      receives the decoded value; its previous content is
    ///                  overwritten (it used to be added to, which was only
    ///                  correct when `data` happened to be 0 on entry).
    /// @param char_nums pointer to at least four bytes, char_nums[0] being
    ///                  the most significant one.
    /// @return the decoded value (same as `data` after the call).
    ///
    /// The function touches no instance state, so it is static; existing
    /// calls made from derived-class member functions keep compiling.
    static long int trainsform_32bitDataform(long int &data, unsigned char* char_nums)
    {
        const unsigned long int value =
              (static_cast<unsigned long int>(char_nums[0]) << 24)
            | (static_cast<unsigned long int>(char_nums[1]) << 16)
            | (static_cast<unsigned long int>(char_nums[2]) << 8)
            |  static_cast<unsigned long int>(char_nums[3]);
        data = static_cast<long int>(value);
        return data;
    }
};

#endif // TRAINSFORMDATA_H

Trainsformdata.cpp

#include "trainsformdata.h"trainsformdata::trainsformdata(){}trainsformdata::~trainsformdata(){}

Mainwindow.cpp

#include "mainwindow.h"
#include "ui_mainwindow.h"
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/opencv.hpp>
#include "qdebug.h"
#include "nisttraindata.h"
#include "nistlabledata.h"
#include <cfloat>
#include <cstring>
#include <fstream>
#include <vector>

using namespace std;
using namespace cv;

#define NTRAINING_SAMPLES   100         // number of training samples per class
#define FRAC_LINEAR_SEP     0.9f        // fraction of samples forming the linearly separable part

// One sample: the label (stored as the ASCII digit '0'..'9') and the pixels
// of the image_row x image_col thumbnail in row-major order.
struct InputData
{
    unsigned char lable;
    float data[NISTTrainData::image_row*NISTTrainData::image_col];
}InputData_;

// Training samples filled by on_pushButton_2_clicked() and consumed by
// on_train_clicked().
vector<InputData> buffer;

// Reads up to `maxSamples` image/label pairs from a pair of MNIST idx files.
//
// `samples` is cleared first and then receives one entry per pair that could
// be read completely. Returns false (after printing a message) when a file
// cannot be opened or the image header is not usable; in that case `samples`
// is left untouched.
static bool loadMnistSamples(const char* imagePath, const char* labelPath,
                             long int maxSamples, vector<InputData>& samples)
{
    NISTTrainData TData;
    NISTLableData LData;

    ifstream lab_ifs(labelPath, ios_base::binary);
    ifstream ifs(imagePath, ios_base::binary);
    if( ifs.fail() )
    {
        cout<<"train fail"<<endl;
        return false;
    }
    if( lab_ifs.fail() )
    {
        cout<<"labelFile fail"<<endl;
        return false;
    }

    // Every header field is exactly 4 bytes in the file. The destination
    // arrays are 4 bytes as well, so size the reads by the arrays and not by
    // sizeof(long int), which is 8 on LP64 platforms.
    ifs.read(reinterpret_cast<char*>(TData.magicNum), sizeof(TData.magicNum));
    ifs.read(reinterpret_cast<char*>(TData.ccount), sizeof(TData.ccount));
    ifs.read(reinterpret_cast<char*>(TData.crows), sizeof(TData.crows));
    ifs.read(reinterpret_cast<char*>(TData.ccols), sizeof(TData.ccols));
    if( !ifs )
    {
        cout<<"image file header is truncated"<<endl;
        return false;
    }
    TData.trainsform_Dataforms();
    TData.show_Data();

    lab_ifs.read(reinterpret_cast<char*>(LData.magic_numbers), sizeof(LData.magic_numbers));
    lab_ifs.read(reinterpret_cast<char*>(LData.number_items), sizeof(LData.number_items));
    if( !lab_ifs )
    {
        cout<<"label file header is truncated"<<endl;
        return false;
    }
    LData.trainsform_Dataforms();
    LData.show_Data();

    if( !TData.check_magic_number() )
    {
        cout<<"image file has an unexpected magic number"<<endl;
        return false;
    }

    // The pixel buffer below is fixed at 28x28; refuse any other geometry
    // rather than reading past the end of it.
    Mat src = Mat::zeros(28, 28, CV_8UC1);
    if( TData.getrows() != src.rows || TData.getcols() != src.cols )
    {
        cout<<"image file does not contain 28x28 images"<<endl;
        return false;
    }
    const long int bytesPerImage = TData.getrows() * TData.getcols();

    long int wanted = TData.getnumber_of_images();
    if( wanted > maxSamples )
        wanted = maxSamples;

    samples.clear();
    if( wanted > 0 )
        samples.reserve(wanted);

    Mat roi;
    InputData sample;
    for(long int count = 0; count < wanted; count++)
    {
        // Stop at the first incomplete record instead of processing stale
        // buffer contents.
        if( !ifs.read(reinterpret_cast<char*>(src.data), bytesPerImage) )
            break;
        if( !lab_ifs.read(reinterpret_cast<char*>(&LData.lable), sizeof(LData.lable)) )
            break;

        TData.GetROI(src, roi);

        // Labels are kept as ASCII digits so they print readably.
        LData.lable = LData.lable + '0';
        cout<<"lable:"<<LData.lable<<endl;

        sample.lable = LData.lable;
        for(int i = 0; i < TData.image_row; i++)
        {
            for(int j = 0; j < TData.image_col; j++)
            {
                sample.data[i*TData.image_col + j] = roi.at<uchar>(i, j);
            }
        }
        samples.push_back(sample);
    }

    cout<<"load trainingdata ok"<<endl;
    return true;
}

// Copies one sample's pixels into the 1 x featureLen CV_32FC1 matrix `row`
// and normalises it in place (cv::normalize with its default arguments).
static void toNormalizedRow(const InputData& sample, Mat& row)
{
    memcpy(row.data, sample.data, sizeof(sample.data));
    normalize(row, row);
}

// Slot: load the training set (at most 2000 samples) into `buffer`.
void MainWindow::on_pushButton_2_clicked()
{
    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";
    const int total = 2000;

    // Load into a local vector and swap on success, so clicking the button
    // again replaces the training set instead of appending duplicates, and a
    // failed load leaves the previous set intact.
    vector<InputData> loaded;
    if( !loadMnistSamples(fileName, labelFileName, total, loaded) )
        return;
    buffer.swap(loaded);

    // NOTE(review): the backspace-escape output below looks like leftover
    // debugging; it is kept so the program's console output is unchanged.
    cout<<"123\b456";
    cout<<"\b"<<endl;
    std::cout<<"hello\b123"<<std::endl;
}

// Slot: train an RBF C-SVC on `buffer`, save it to SVM_DATA.xml, then reload
// the file and predict the first training sample as a smoke test.
void MainWindow::on_train_clicked()
{
    const vector<InputData>& trainData = buffer;
    if( trainData.empty() )
    {
        cout<<"no training data loaded"<<endl;
        return;
    }

    const int sampleCount = static_cast<int>(trainData.size());
    const int featureLen = NISTTrainData::image_col*NISTTrainData::image_row;
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);
    for(int i = 0; i < sampleCount; i++)
    {
        const InputData& td = trainData[i];
        toNormalizedRow(td, m);
        Mat row = data.row(i);      // header sharing storage with `data`
        m.copyTo(row);
        res.at<int>(i, 0) = td.lable;   // CV_32SC1 holds signed int
    }

    CvSVM svm;
    CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    //param= CvSVMParams(CvSVM::C_SVC, CvSVM::LINEAR, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    cout<<"training..."<<endl<<"it takes a long time, please wait!"<<endl;
    if( !svm.train(data, res, Mat(), Mat(), param) )
    {
        cout<<"training failed"<<endl;
        return;
    }
    cout<<"training finished..."<<endl;

    cout<<"saving \"SVM_DATA.xml\"..."<<endl;
    svm.save( "SVM_DATA.xml" );
    cout<<"saved..."<<endl;

    // Round-trip check: reload the saved model and classify sample 0.
    CvSVM svmpredict;
    svmpredict.load( "SVM_DATA.xml" );
    const InputData& td = trainData[0];
    toNormalizedRow(td, m);
    char ret = (char)svmpredict.predict(m);
    cout<<"ret is :"<<ret<<endl;
    cout<<"labble is :"<<td.lable<<endl;
}

// Slot: load the test set (at most 10000 samples), classify every sample
// with the model stored in SVM_DATA.xml and print the accuracy.
void MainWindow::on_testPredict_clicked()
{
    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    const int total = 10000;

    vector<InputData> Testbuffer;
    if( !loadMnistSamples(fileName, labelFileName, total, Testbuffer) )
        return;

    const int testCount = static_cast<int>(Testbuffer.size());
    if( testCount == 0 )
    {
        cout<<"no test data loaded"<<endl;
        return;
    }

    const int featureLen = NISTTrainData::image_col*NISTTrainData::image_row;
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);

    CvSVM svmpredict1;
    svmpredict1.load( "SVM_DATA.xml" );
    if( svmpredict1.get_var_count() != featureLen )
    {
        // Also covers a model that could not be loaded at all.
        cout<<"SVM_DATA.xml is missing or does not match the feature length"<<endl;
        return;
    }

    cout<<"testing..."<<endl;
    // Report progress roughly once per percent; never let the step be 0.
    const int progressStep = (testCount >= 100) ? testCount / 100 : 1;
    int count_test = 0;
    for(int i = 0; i < testCount; i++)
    {
        const InputData& td = Testbuffer[i];
        toNormalizedRow(td, m);
        char ret = (char)svmpredict1.predict(m);
        if(ret == td.lable)
        {
            count_test++;
        }
        if(i % progressStep == 0)
        {
            cout<<(i*100)/testCount<<"%"<<endl;
        }
    }

    cout<<"test finished!"<<endl;
    cout<<"crect:"<<(count_test*1.0/testCount)*100<<"%"<<endl;
    cout<<"totall:"<<count_test<<endl;
}

其中一个数据(已经被博主归一化了大小):

导入数据的输出提示:


模型训练提示输出:


测试集测试结果:线性核下正确率为92.83%,低于上面网站上给出的正确率,可能是参数没有设置好。




0 0