HOG特征的SVM分类器训练代码

来源:互联网 发布:淘宝开店信誉 编辑:程序博客网 时间:2024/06/05 10:16
[cpp] view plaincopy
  1. #include <cv.h>  
  2. #include <io.h>  
  3. #include <iostream>  
  4. #include <string>  
  5. #include <highgui.h>  
  6. #include <ml.h>  
  7. using namespace std;  
  8. class Mysvm: public CvSVM  
  9. {  
  10. public:  
  11.     int get_alpha_count()  
  12.     {  
  13.         return this->sv_total;  
  14.     }  
  15.   
  16.     int get_sv_dim()  
  17.     {  
  18.         return this->var_all;  
  19.     }  
  20.   
  21.     int get_sv_count()  
  22.     {  
  23.         return this->decision_func->sv_count;  
  24.     }  
  25.   
  26.     double* get_alpha()  
  27.     {  
  28.         return this->decision_func->alpha;  
  29.     }  
  30.   
  31.     float** get_sv()  
  32.     {  
  33.         return this->sv;  
  34.     }  
  35.   
  36.     float get_rho()  
  37.     {  
  38.         return this->decision_func->rho;  
  39.     }  
  40. };  
  41.   
  42. void Train()  
  43. {  
  44.     char classifierSavePath[256] = "E:/work/INRIAPerson/pedestrianDetect-peopleFlow.txt";  
  45.   
  46.     string positivePath = "E:\\work\\INRIAPerson\\posgray\\";  
  47.     string negativePath = "E:\\work\\INRIAPerson\\neggray\\";  
  48.   
  49.     int positiveSampleCount = 2416;  
  50.     int negativeSampleCount = 3654;  
  51.     int totalSampleCount = positiveSampleCount + negativeSampleCount;  
  52.   
  53.     cout<<"//////////////////////////////////////////////////////////////////"<<endl;  
  54.     cout<<"totalSampleCount: "<<totalSampleCount<<endl;  
  55.     cout<<"positiveSampleCount: "<<positiveSampleCount<<endl;  
  56.     cout<<"negativeSampleCount: "<<negativeSampleCount<<endl;  
  57.   
  58.     CvMat *sampleFeaturesMat = cvCreateMat(totalSampleCount , 3780, CV_32FC1);  
  59.     //64*128的训练样本,该矩阵将是totalSample*3780,64*64的训练样本,该矩阵将是totalSample*1764  
  60.     cvSetZero(sampleFeaturesMat);    
  61.     CvMat *sampleLabelMat = cvCreateMat(totalSampleCount, 1, CV_32FC1);//样本标识    
  62.     cvSetZero(sampleLabelMat);    
  63.   
  64.     cout<<"************************************************************"<<endl;  
  65.     cout<<"start to training positive samples..."<<endl;  
  66.   
  67.     char positiveImgName[256];  
  68.     string path;  
  69.     for(int i=0; i<positiveSampleCount; i++)    
  70.     {    
  71.         memset(positiveImgName, '\0', 256*sizeof(char));  
  72.         sprintf(positiveImgName, "%d.png", i+1);  
  73.         int len = strlen(positiveImgName);  
  74.         string tempStr = positiveImgName;  
  75.         path = positivePath + tempStr;  
  76.   
  77.         cv::Mat img = cv::imread(path);  
  78.         if( img.data == NULL )  
  79.         {  
  80.             cout<<"positive image sample load error: "<<i<<" "<<path<<endl;  
  81.             system("pause");  
  82.             continue;  
  83.         }  
  84.   
  85.         cv::HOGDescriptor hog(cv::Size(64,128), cv::Size(16,16), cv::Size(8,8), cv::Size(8,8), 9);  
  86.         vector<float> featureVec;   
  87.   
  88.         hog.compute(img, featureVec, cv::Size(8,8));    
  89.         int featureVecSize = featureVec.size();  
  90.   
  91.         for (int j=0; j<featureVecSize; j++)    
  92.         {         
  93.             CV_MAT_ELEM( *sampleFeaturesMat, float, i, j ) = featureVec[j];  
  94.         }    
  95.         sampleLabelMat->data.fl[i] = 1;  
  96.     }  
  97.     cout<<"end of training for positive samples..."<<endl;  
  98.   
  99.     cout<<"*********************************************************"<<endl;  
  100.     cout<<"start to train negative samples..."<<endl;  
  101.   
  102.     char negativeImgName[256];  
  103.     for (int i=0; i<negativeSampleCount; i++)  
  104.     {    
  105.         memset(negativeImgName, '\0', 256*sizeof(char));  
  106.         sprintf(negativeImgName, "%d.png", i+1);  
  107.         path = negativePath + negativeImgName;  
  108.         cv::Mat img = cv::imread(path);  
  109.         if(img.data == NULL)  
  110.         {  
  111.             cout<<"negative image sample load error: "<<path<<endl;  
  112.             continue;  
  113.         }  
  114.   
  115.         cv::HOGDescriptor hog(cv::Size(64,128), cv::Size(16,16), cv::Size(8,8), cv::Size(8,8), 9);    
  116.         vector<float> featureVec;   
  117.   
  118.         hog.compute(img,featureVec,cv::Size(8,8));//计算HOG特征  
  119.         int featureVecSize = featureVec.size();    
  120.   
  121.         for ( int j=0; j<featureVecSize; j ++)    
  122.         {    
  123.             CV_MAT_ELEM( *sampleFeaturesMat, float, i + positiveSampleCount, j ) = featureVec[ j ];  
  124.         }    
  125.   
  126.         sampleLabelMat->data.fl[ i + positiveSampleCount ] = -1;  
  127.     }    
  128.   
  129.     cout<<"end of training for negative samples..."<<endl;  
  130.     cout<<"********************************************************"<<endl;  
  131.     cout<<"start to train for SVM classifier..."<<endl;  
  132.   
  133.     CvSVMParams params;    
  134.     params.svm_type = CvSVM::C_SVC;    
  135.     params.kernel_type = CvSVM::LINEAR;    
  136.     params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, FLT_EPSILON);  
  137.     params.C = 0.01;  
  138.   
  139.     Mysvm svm;  
  140.     svm.train( sampleFeaturesMat, sampleLabelMat, NULL, NULL, params ); //用SVM线性分类器训练  
  141.     svm.save(classifierSavePath);  
  142.   
  143.     cvReleaseMat(&sampleFeaturesMat);  
  144.     cvReleaseMat(&sampleLabelMat);  
  145.   
  146.     int supportVectorSize = svm.get_support_vector_count();  
  147.     cout<<"support vector size of SVM:"<<supportVectorSize<<endl;  
  148.     cout<<"************************ end of training for SVM ******************"<<endl;  
  149.   
  150.     CvMat *sv,*alp,*re;//所有样本特征向量   
  151.     sv  = cvCreateMat(supportVectorSize , 3780, CV_32FC1);  
  152.     alp = cvCreateMat(1 , supportVectorSize, CV_32FC1);  
  153.     re  = cvCreateMat(1 , 3780, CV_32FC1);  
  154.     CvMat *res  = cvCreateMat(1 , 1, CV_32FC1);  
  155.   
  156.     cvSetZero(sv);  
  157.     cvSetZero(re);  
  158.     
  159.     for(int i=0; i<supportVectorSize; i++)  
  160.     {  
  161.         memcpy( (float*)(sv->data.fl+i*3780), svm.get_support_vector(i), 3780*sizeof(float));      
  162.     }  
  163.   
  164.     double* alphaArr = svm.get_alpha();  
  165.     int alphaCount = svm.get_alpha_count();  
  166.   
  167.     for(int i=0; i<supportVectorSize; i++)  
  168.     {  
  169.         alp->data.fl[i] = alphaArr[i];  
  170.     }  
  171.     cvMatMul(alp, sv, re);  
  172.   
  173.     int posCount = 0;  
  174.     for (int i=0; i<3780; i++)  
  175.     {  
  176.         re->data.fl[i] *= -1;  
  177.     }  
  178.   
  179.     FILE* fp = fopen("E:/work/INRIAPerson/hogSVMDetector-peopleFlow1.txt","wb");  
  180.     if( NULL == fp )  
  181.     {  
  182.         return;  
  183.     }  
  184.     for(int i=0; i<3780; i++)  
  185.     {  
  186.         fprintf(fp,"%f \n",re->data.fl[i]);  
  187.     }  
  188.     float rho = svm.get_rho();  
  189.     fprintf(fp, "%f", rho);  
  190.     cout<<"E:/work/INRIAPerson/hogSVMDetector.txt 保存完毕"<<endl;//保存HOG能识别的分类器  
  191.     fclose(fp);  
  192.   
  193.     return;  
  194. }  
  195.   
  196. int main()  
  197. {  
  198.     Train();  
  199.     return 0;  
  200. }