opencv haartraining 分析三:icvCreateCARTStageClassifier

来源:互联网 发布:centos 解压 编辑:程序博客网 时间:2024/06/03 20:37

CvIntHaarClassifier* icvCreateCARTStageClassifier(CvHaarTrainingData* data,
                                                  CvMat* sampleIdx,
                                                  CvIntHaarFeatures* haarFeatures,
                                                  float minhitrate,
                                                  float maxfalsealarm,
                                                  int   symmetric,
                                                  float weightfraction,
                                                  int numsplits,
                                                  CvBoostType boosttype,
                                                  CvStumpError stumperror,
                                                  int maxsplits )

{

#ifdef CV_COL_ARRANGEMENT
    int flags =CV_COL_SAMPLE;
#else
    int flags =CV_ROW_SAMPLE;
#endif

   CvStageHaarClassifier* stage = NULL;
   CvBoostTrainer* trainer;
   CvCARTClassifier* cart = NULL;
   CvCARTTrainParams trainParams;
   CvMTStumpTrainParams stumpTrainParams;
    //CvMat*trainData = NULL;
    //CvMat*sortedIdx = NULL;
    CvMateval;
    int n =0;
    int m =0;
    int numpos =0;
    int numneg =0;
    int numfalse= 0;
    floatsum_stage = 0.0F;
    floatthreshold = 0.0F;
    floatfalsealarm = 0.0F;
   
    //CvMat*sampleIdx = NULL;
    CvMat*trimmedIdx;
    //float*idxdata = NULL;
    //float*tempweights = NULL;
   //int   idxcount = 0;
    CvUserdatauserdata;

    int i =0;
    int j =0;
    intidx;
    intnumsamples;
    intnumtrimmed;
   
   CvCARTHaarClassifier* classifier;
    CvSeq* seq =NULL;
   CvMemStorage* storage = NULL;
    CvMat*weakTrainVals;
    floatalpha;
    floatsumalpha;
    intnum_splits;

#ifdef CV_VERBOSE
    printf("+----+----+-+---------+---------+---------+---------+\n" );
    printf("|  N |%%SMP|F|  ST.THR|   HR    FA   | EXP. ERR|\n" );
    printf("+----+----+-+---------+---------+---------+---------+\n" );
#endif
   
    n =haarFeatures->count;//这是haar特征的数目,对于32*32的子窗口,特征数目为26万多
    m =data->sum.rows;
    numsamples =(sampleIdx) ? MAX( sampleIdx->rows,sampleIdx->cols ) : m;

    userdata= cvUserdata( data, haarFeatures );

   stumpTrainParams.type = ( boosttype == CV_DABCLASS )
       ? CV_CLASSIFICATION_CLASS : CV_REGRESSION;
   stumpTrainParams.error = ( boosttype == CV_LBCLASS || boosttype ==CV_GABCLASS )
       ? CV_SQUARE : stumperror;
   stumpTrainParams.portion = CV_STUMP_TRAIN_PORTION;
   stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
   stumpTrainParams.numcomp = n;
   stumpTrainParams.userdata = &userdata;
   stumpTrainParams.sortedIdx =data->idxcache;//这是对构建cart的每个节点的stump一级决策树参数的设置

   trainParams.count = numsplits;
   trainParams.stumpTrainParams = (CvClassifierTrainParams*)&stumpTrainParams;
   trainParams.stumpConstructor = cvCreateMTStumpClassifier;
   trainParams.splitIdx = icvSplitIndicesCallback;
   trainParams.userdata =&userdata;//这是对cart弱分类器参数的设置

    eval =cvMat( 1, m, CV_32FC1, cvAlloc( sizeof( float ) * m ) );
   
    storage =cvCreateMemStorage();
    seq =cvCreateSeq( 0, sizeof( *seq ), sizeof( classifier ), storage);

   weakTrainVals = cvCreateMat( 1, m, CV_32FC1 );
    trainer =cvBoostStartTraining( &data->cls,weakTrainVals, &data->weights,
                                   sampleIdx, boosttype);//这是用data->cls来计算weakTrainVals。其中weakTrainVals=2*cls-1,cls属于{0,1},则weakTrainVals属于{-1,1}
    num_splits =0;
    sumalpha =0.0F;
    do
      

#ifdef CV_VERBOSE
       int v_wt = 0;
       int v_flipped = 0;
#endif

       trimmedIdx = cvTrimWeights(&data->weights, sampleIdx,weightfraction );//剔除小权值,由weightfraction来控制。
       numtrimmed = (trimmedIdx) ? MAX( trimmedIdx->rows,trimmedIdx->cols ) : m;

#ifdef CV_VERBOSE
       v_wt = 100 * numtrimmed / numsamples;
       v_flipped = 0;

#endif

       cart = (CvCARTClassifier*) cvCreateCARTClassifier(data->valcache,
                       flags,
                       weakTrainVals, 0, 0, 0, trimmedIdx,
                       &(data->weights),
                       (CvClassifierTrainParams*) &trainParams);//开始构建cart树弱分类器

       classifier = (CvCARTHaarClassifier*) icvCreateCARTHaarClassifier(numsplits );
       icvInitCARTHaarClassifier( classifier, cart, haarFeatures );

       num_splits += classifier->count;

       cart->release( (CvClassifier**)&cart );
       
       if( symmetric &&(seq->total % 2) )
       {
           float normfactor = 0.0F;
           CvStumpClassifier* stump;
           
           
           for( i = 0; i < classifier->count;i++ )
           {
               if( classifier->feature[i].desc[0] == 'h' )
               {
                   for( j = 0; j < CV_HAAR_FEATURE_MAX&&
                                   classifier->feature[i].rect[j].weight != 0.0F; j++)
                   {
                       classifier->feature[i].rect[j].r.x =data->winsize.width -
                           classifier->feature[i].rect[j].r.x -
                           classifier->feature[i].rect[j].r.width;               
                   }
               }
               else
               {
                   int tmp = 0;

                   
                   
                   for( j = 0; j < CV_HAAR_FEATURE_MAX&&
                                   classifier->feature[i].rect[j].weight != 0.0F; j++)
                   {
                       classifier->feature[i].rect[j].r.x =data->winsize.width -
                           classifier->feature[i].rect[j].r.x;
                       CV_SWAP(classifier->feature[i].rect[j].r.width,
                                classifier->feature[i].rect[j].r.height, tmp);
                   }
               }
           }
           icvConvertToFastHaarFeature(classifier->feature,
                                        classifier->fastfeature,
                                        classifier->count,data->winsize.width + 1 );

           stumpTrainParams.getTrainData = NULL;
           stumpTrainParams.numcomp = 1;
           stumpTrainParams.userdata = NULL;
           stumpTrainParams.sortedIdx = NULL;

           for( i = 0; i < classifier->count;i++ )
           {
               for( j = 0; j < numtrimmed; j++ )
               {
                   idx = icvGetIdxAt( trimmedIdx, j );

                   eval.data.fl[idx] = cvEvalFastHaarFeature(&classifier->fastfeature[i],
                       (sum_type*) (data->sum.data.ptr + idx *data->sum.step),
                       (sum_type*) (data->tilted.data.ptr + idx *data->tilted.step) );
                   normfactor = data->normfactor.data.fl[idx];
                   eval.data.fl[idx] = ( normfactor == 0.0F )
                       ? 0.0F : (eval.data.fl[idx] / normfactor);
               }

               stump = (CvStumpClassifier*) trainParams.stumpConstructor(&eval,
                   CV_COL_SAMPLE,
                   weakTrainVals, 0, 0, 0, trimmedIdx,
                   &(data->weights),
                   trainParams.stumpTrainParams );
           
               classifier->threshold[i] =stump->threshold;
               if( classifier->left[i] <= 0 )
               {
                   classifier->val[-classifier->left[i]]= stump->left;
               }
               if( classifier->right[i] <= 0 )
               {
                   classifier->val[-classifier->right[i]]= stump->right;
               }

               stump->release( (CvClassifier**)&stump);       
               
           }

           stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
           stumpTrainParams.numcomp = n;
           stumpTrainParams.userdata = &userdata;
           stumpTrainParams.sortedIdx = data->idxcache;

#ifdef CV_VERBOSE
           v_flipped = 1;
#endif

       }
       if( trimmedIdx != sampleIdx )
       {
           cvReleaseMat( &trimmedIdx );
           trimmedIdx = NULL;
       }
       
       for( i = 0; i < numsamples; i++ )
       {
           idx = icvGetIdxAt( sampleIdx, i );

           eval.data.fl[idx] = classifier->eval_r((CvIntHaarClassifier*) classifier,
               (sum_type*) (data->sum.data.ptr + idx *data->sum.step),
               (sum_type*) (data->tilted.data.ptr + idx *data->tilted.step),
               data->normfactor.data.fl[idx] );
       }

       alpha = cvBoostNextWeakClassifier( &eval,&data->cls, weakTrainVals,
                                          &data->weights, trainer );
       sumalpha += alpha;
       
       for( i = 0; i <= classifier->count;i++ )
       {
           if( boosttype == CV_RABCLASS )
           {
               classifier->val[i] = cvLogRatio(classifier->val[i] );
           }
           classifier->val[i] *= alpha;
       }

       cvSeqPush( seq, (void*) &classifier );

       numpos = 0;
       for( i = 0; i < numsamples; i++ )
       {
           idx = icvGetIdxAt( sampleIdx, i );

           if( data->cls.data.fl[idx] == 1.0F )
           {
               eval.data.fl[numpos] = 0.0F;
               for( j = 0; j < seq->total; j++)
               {
                   classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j));
                   eval.data.fl[numpos] += classifier->eval_r(
                       (CvIntHaarClassifier*) classifier,
                       (sum_type*) (data->sum.data.ptr + idx *data->sum.step),
                       (sum_type*) (data->tilted.data.ptr + idx *data->tilted.step),
                       data->normfactor.data.fl[idx] );
               }
               
               numpos++;
           }
       }
       icvSort_32f( eval.data.fl, numpos, 0 );
       threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)];

       numneg = 0;
       numfalse = 0;
       for( i = 0; i < numsamples; i++ )
       {
           idx = icvGetIdxAt( sampleIdx, i );

           if( data->cls.data.fl[idx] == 0.0F )
           {
               numneg++;
               sum_stage = 0.0F;
               for( j = 0; j < seq->total; j++)
               {
                  classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j));
                  sum_stage += classifier->eval_r((CvIntHaarClassifier*) classifier,
                       (sum_type*) (data->sum.data.ptr + idx *data->sum.step),
                       (sum_type*) (data->tilted.data.ptr + idx *data->tilted.step),
                       data->normfactor.data.fl[idx] );
               }
               
               if( sum_stage >= (threshold - CV_THRESHOLD_EPS))
               {
                   numfalse++;
               }
           }
       }
       falsealarm = ((float) numfalse) / ((float) numneg);

#ifdef CV_VERBOSE
       {
           floatv_hitrate    =0.0F;
           float v_falsealarm = 0.0F;
           
           float v_experr = 0.0F;

           for( i = 0; i < numsamples; i++ )
           {
               idx = icvGetIdxAt( sampleIdx, i );

               sum_stage = 0.0F;
               for( j = 0; j < seq->total; j++)
               {
                   classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j));
                   sum_stage += classifier->eval_r((CvIntHaarClassifier*) classifier,
                       (sum_type*) (data->sum.data.ptr + idx *data->sum.step),
                       (sum_type*) (data->tilted.data.ptr + idx *data->tilted.step),
                       data->normfactor.data.fl[idx] );
               }
               
               if( sum_stage >= (threshold - CV_THRESHOLD_EPS))
               {
                   if( data->cls.data.fl[idx] == 1.0F )
                   {
                       v_hitrate += 1.0F;
                   }
                   else
                   {
                       v_falsealarm += 1.0F;
                   }
               }
               if( ( sum_stage >= 0.0F ) !=(data->cls.data.fl[idx] == 1.0F) )
               {
                   v_experr += 1.0F;
               }
           }
           v_experr /= numsamples;
           printf( "|M|=%%|%c|�|�|�|�|\n",
               seq->total, v_wt, ( (v_flipped) ? '+' : '-' ),
               threshold, v_hitrate / numpos, v_falsealarm / numneg,
               v_experr );
           printf( "+----+----+-+---------+---------+---------+---------+\n");
           fflush( stdout );
       }
#endif
       
    } while(falsealarm > maxfalsealarm&& (!maxsplits || (num_splits< maxsplits) ) );
   cvBoostEndTraining( &trainer );

    if(falsealarm > maxfalsealarm )
    {
       stage = NULL;
    }
    else
    {
       stage = (CvStageHaarClassifier*) icvCreateStageHaarClassifier(seq->total,
                                                                      threshold );
       cvCvtSeqToArray( seq, (CvArr*) stage->classifier);
    }
   
   
   cvReleaseMemStorage( &storage );
   cvReleaseMat( &weakTrainVals );
    cvFree(&(eval.data.ptr) );
   
    return(CvIntHaarClassifier*) stage;
}

0 0
原创粉丝点击