【opencv】神经网络识别字母+数字

来源:互联网 发布:淘宝找朋友代付 编辑:程序博客网 时间:2024/05/08 23:31

继承自本人博客:

http://blog.csdn.net/qq_15947787/article/details/51385861

原文只是识别数字0-9,简单修改后可以识别24个字母(除了I,O)与数字。

把0与O看成一类,1与I看成一类

附件从原文下载即可。

[cpp] view plain copy
  1. //opencv2.4.9 + vs2012 + 64位  
  2. #include <windows.h>  
  3. #include <iostream>  
  4. #include <fstream>  
  5. #include <opencv2/opencv.hpp>  
  6.   
  7. using namespace cv;  
  8. using namespace std;  
  9.   
  10. char* WcharToChar(const wchar_t* wp)    
  11. {    
  12.     char *m_char;  
  13.     int len= WideCharToMultiByte(CP_ACP,0,wp,wcslen(wp),NULL,0,NULL,NULL);    
  14.     m_char=new char[len+1];    
  15.     WideCharToMultiByte(CP_ACP,0,wp,wcslen(wp),m_char,len,NULL,NULL);    
  16.     m_char[len]='\0';    
  17.     return m_char;    
  18. }    
  19.   
  20. wchar_t* CharToWchar(const char* c)    
  21. {     
  22.     wchar_t *m_wchar;  
  23.     int len = MultiByteToWideChar(CP_ACP,0,c,strlen(c),NULL,0);    
  24.     m_wchar=new wchar_t[len+1];    
  25.     MultiByteToWideChar(CP_ACP,0,c,strlen(c),m_wchar,len);    
  26.     m_wchar[len]='\0';    
  27.     return m_wchar;    
  28. }    
  29.   
  30. wchar_t* StringToWchar(const string& s)    
  31. {    
  32.     const char* p=s.c_str();    
  33.     return CharToWchar(p);    
  34. }    
  35.   
  36. int main()  
  37. {  
  38.     const string fileform = "*.png";  
  39.     const string perfileReadPath = "charSamples";  
  40.   
  41.     const int sample_mun_perclass = 20;//训练字符每类数量  
  42.     const int class_mun = 10+26;//训练字符类数 0-9 A-Z 除了I、O  
  43.   
  44.     const int image_cols = 8;  
  45.     const int image_rows = 16;  
  46.     string  fileReadName,  
  47.             fileReadPath;  
  48.     char temp[256];  
  49.   
  50.     float trainingData[class_mun*sample_mun_perclass][image_rows*image_cols] = {{0}};//每一行一个训练样本  
  51.     float labels[class_mun*sample_mun_perclass][class_mun]={{0}};//训练样本标签  
  52.   
  53.     for(int i = 0; i <= class_mun - 1; i++)//不同类  
  54.     {  
  55.         //读取每个类文件夹下所有图像  
  56.         int j = 0;//每一类读取图像个数计数  
  57.   
  58.         if (i <= 9)//0-9  
  59.         {  
  60.             sprintf(temp, "%d", i);  
  61.             //printf("%d\n", i);  
  62.         }  
  63.         else//A-Z  
  64.         {  
  65.             sprintf(temp, "%c", i + 55);  
  66.             //printf("%c\n", i+55);  
  67.         }  
  68.                
  69.         fileReadPath = perfileReadPath + "/" + temp + "/" + fileform;  
  70.         cout<<"文件夹"<<temp<<endl;  
  71.           
  72.         HANDLE hFile;  
  73.         LPCTSTR lpFileName = StringToWchar(fileReadPath);//指定搜索目录和文件类型,如搜索d盘的音频文件可以是"D:\\*.mp3"  
  74.         WIN32_FIND_DATA pNextInfo;  //搜索得到的文件信息将储存在pNextInfo中;  
  75.         hFile = FindFirstFile(lpFileName,&pNextInfo);//请注意是 &pNextInfo , 不是 pNextInfo;  
  76.         if(hFile == INVALID_HANDLE_VALUE)  
  77.         {  
  78.             continue;//搜索失败  
  79.         }  
  80.         //do-while循环读取  
  81.         do  
  82.         {     
  83.             if(pNextInfo.cFileName[0] == '.')//过滤.和..  
  84.                 continue;  
  85.             j++;//读取一张图  
  86.             //wcout<<pNextInfo.cFileName<<endl;  
  87.             //printf("%s\n",WcharToChar(pNextInfo.cFileName));  
  88.             //对读入的图片进行处理  
  89.             Mat srcImage = imread( perfileReadPath + "/" + temp + "/" + WcharToChar(pNextInfo.cFileName),CV_LOAD_IMAGE_GRAYSCALE);  
  90.             Mat resizeImage;  
  91.             Mat trainImage;  
  92.             Mat result;  
  93.   
  94.             resize(srcImage,resizeImage,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);//使用象素关系重采样。当图像缩小时候,该方法可以避免波纹出现  
  95.             threshold(resizeImage,trainImage,0,255,CV_THRESH_BINARY|CV_THRESH_OTSU);  
  96.   
  97.             for(int k = 0; k<image_rows*image_cols; ++k)  
  98.             {  
  99.                 trainingData[i*sample_mun_perclass+(j-1)][k] = (float)trainImage.data[k];  
  100.                 //trainingData[i*sample_mun_perclass+(j-1)][k] = (float)trainImage.at<unsigned char>((int)k/8,(int)k%8);//(float)train_image.data[k];  
  101.                 //cout<<trainingData[i*sample_mun_perclass+(j-1)][k] <<" "<< (float)trainImage.at<unsigned char>(k/8,k%8)<<endl;  
  102.             }  
  103.   
  104.         }while (FindNextFile(hFile,&pNextInfo) && j<sample_mun_perclass);//如果设置读入的图片数量,则以设置的为准,如果图片不够,则读取文件夹下所有图片  
  105.       
  106.     }  
  107.   
  108.     // Set up training data Mat  
  109.     Mat trainingDataMat(class_mun*sample_mun_perclass, image_rows*image_cols, CV_32FC1, trainingData);  
  110.     cout<<"trainingDataMat——OK!"<<endl;  
  111.   
  112.     // Set up label data   
  113.     for(int i = 0;i <= class_mun-1; ++i)  
  114.     {  
  115.         for(int j = 0;j <= sample_mun_perclass - 1; ++j)  
  116.         {  
  117.             for(int k = 0;k < class_mun; ++k)  
  118.             {  
  119.                 if(k == i)  
  120.                     if (k == 18)  
  121.                     {  
  122.                         labels[i*sample_mun_perclass + j][1] = 1;  
  123.                     }  
  124.                     else if(k == 24)  
  125.                     {  
  126.                         labels[i*sample_mun_perclass + j][0] = 1;  
  127.                     }  
  128.                     else  
  129.                     {  
  130.                         labels[i*sample_mun_perclass + j][k] = 1;  
  131.                     }        
  132.                 else   
  133.                     labels[i*sample_mun_perclass + j][k] = 0;  
  134.             }  
  135.         }  
  136.     }  
  137.     Mat labelsMat(class_mun*sample_mun_perclass, class_mun, CV_32FC1,labels);  
  138.     cout<<"labelsMat:"<<endl;  
  139.     ofstream outfile("out.txt");  
  140.     outfile<<labelsMat;  
  141.     //cout<<labelsMat<<endl;  
  142.     cout<<"labelsMat——OK!"<<endl;  
  143.   
  144.     //训练代码  
  145.   
  146.     cout<<"training start...."<<endl;  
  147.     CvANN_MLP bp;  
  148.     // Set up BPNetwork's parameters  
  149.     CvANN_MLP_TrainParams params;  
  150.     params.train_method=CvANN_MLP_TrainParams::BACKPROP;  
  151.     params.bp_dw_scale=0.001;  
  152.     params.bp_moment_scale=0.1;  
  153.     params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER|CV_TERMCRIT_EPS,10000,0.0001);  //设置结束条件  
  154.     //params.train_method=CvANN_MLP_TrainParams::RPROP;  
  155.     //params.rp_dw0 = 0.1;  
  156.     //params.rp_dw_plus = 1.2;  
  157.     //params.rp_dw_minus = 0.5;  
  158.     //params.rp_dw_min = FLT_EPSILON;  
  159.     //params.rp_dw_max = 50.;  
  160.   
  161.     //Setup the BPNetwork  
  162.     Mat layerSizes=(Mat_<int>(1,5) << image_rows*image_cols,128,128,128,class_mun);  
  163.     bp.create(layerSizes,CvANN_MLP::SIGMOID_SYM,1.0,1.0);//CvANN_MLP::SIGMOID_SYM  
  164.                                                //CvANN_MLP::GAUSSIAN  
  165.                                                //CvANN_MLP::IDENTITY  
  166.     cout<<"training...."<<endl;  
  167.     bp.train(trainingDataMat, labelsMat, Mat(),Mat(), params);  
  168.   
  169.     bp.save("../bpcharModel.xml"); //save classifier  
  170.     cout<<"training finish...bpModel1.xml saved "<<endl;  
  171.   
  172.   
  173.     //测试神经网络  
  174.     cout<<"测试:"<<endl;  
  175.     Mat test_image = imread("test4.png",CV_LOAD_IMAGE_GRAYSCALE);  
  176.     Mat test_temp;  
  177.     resize(test_image,test_temp,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);//使用象素关系重采样。当图像缩小时候,该方法可以避免波纹出现  
  178.     threshold(test_temp,test_temp,0,255,CV_THRESH_BINARY|CV_THRESH_OTSU);  
  179.     Mat_<float>sampleMat(1,image_rows*image_cols);   
  180.     for(int i = 0; i<image_rows*image_cols; ++i)    
  181.     {    
  182.         sampleMat.at<float>(0,i) = (float)test_temp.at<uchar>(i/8,i%8);    
  183.     }    
  184.       
  185.     Mat responseMat;    
  186.     bp.predict(sampleMat,responseMat);    
  187.     Point maxLoc;  
  188.     double maxVal = 0;  
  189.     minMaxLoc(responseMat,NULL,&maxVal,NULL,&maxLoc);  
  190.   
  191.     if (maxLoc.x <= 9)//0-9  
  192.     {  
  193.         sprintf(temp, "%d", maxLoc.x);  
  194.         //printf("%d\n", i);  
  195.     }  
  196.     else//A-Z  
  197.     {  
  198.         sprintf(temp, "%c", maxLoc.x + 55);  
  199.         //printf("%c\n", i+55);  
  200.     }  
  201.   
  202.     cout<<"识别结果:"<<temp<<"    相似度:"<<maxVal*100<<"%"<<endl;  
  203.     imshow("test_image",test_image);    
  204.     waitKey(0);  
  205.       
  206.     return 0;  

原创粉丝点击