在opencv中,强分类器阈值是如何确定的?虚警率是怎么计算的?

来源:互联网 发布:直播实时转播软件 编辑:程序博客网 时间:2024/04/30 01:20

在opencv中,强分类器阈值的确定实在函数icvCreateCARTStageClassifier中,具体强分类器的阈值的求解方式和虚警率的计算如下

CvIntHaarClassifier* icvCreateCARTStageClassifier(       CvHaarTrainingData* data,        // 训练样本数据,包括图片的大小,数量,积分图,权重,类别等数据       CvMat* sampleIdx,                // 训练样本序列,不一定与积分图的顺序一致       CvIntHaarFeatures* haarFeatures, // 全部HAAR特征       float minhitrate,                // 最小正检率       float maxfalsealarm,             // 最大误检率                 int   symmetric,                 // HAAR特征是否对称       float weightfraction,            // 样本剔除比例(用于剔除小权值样本,以加快训练速度)       int numsplits,                   // 每个弱分类器特征个数(一般为1)       CvBoostType boosttype,           // adaboost类型,一般使用的是DAB       CvStumpError stumperror,         // Discrete AdaBoost(DAB)中的阈值计算方式       int maxsplits )                  // 弱分类器最大个数{               .................                //我把上面强分类器的count个弱分类器的求解省略了            numpos = 0;  //在data中正样本的数量对于程序来说是不知道的,那就需要在程序中求出           /*确定强分类器的阈值threshold          *遍历sampleIdx中所有样本,但是只需计算每个正样本的弱分类器置信度和,具体来说,          *也就是,对于每个正样本,遍历上面求出的所有求出count(在这里是seq->total)个          *弱分类器的置信度并求和,这样共得到numpos个正样本的置信度和,把他们升序排列,          *然后就可以求阈值了,阈值为:          *threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)]。          *这个阈值的含义是根据需要的最小击中率,可以先求出正样本的漏检率,然后乘以正样          *本的数量,这个值转化为int型,就是正样本中漏检的数量,那么对于上面刚排好序第几          *个正样本的弱分类器的置信度和。          */        for( i = 0; i < numsamples; i++ )          {              // 获得样本序号,可能与积分图中样本顺序不一致,所以要求出其序号              idx = icvGetIdxAt( sampleIdx, i );                // 如果样本为正样本              if( data->cls.data.fl[idx] == 1.0F )              {                  // 初始化置信度值                  eval.data.fl[numpos] = 0.0F;                    // 遍历seq中所有弱分类器                  for( j = 0; j < seq->total; j++ )                  {                      // 获取弱分类器                      classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));                        // 累积当前正样本的弱分类器置信度和                      eval.data.fl[numpos] += classifier->eval(                           (CvIntHaarClassifier*) classifier,                          (sum_type*) (data->sum.data.ptr + idx * data->sum.step),                          (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),                          data->normfactor.data.fl[idx] );                  }                  /* eval.data.fl[numpos] = 2.0F * eval.data.fl[numpos] - seq->total; */                  numpos++;              }          }            // 对弱分类器输出置信度和进行排序          icvSort_32f( eval.data.fl, numpos, 0 );            // 计算阈值,应该是大于threshold则为正类,小于threshold则为负类          threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)];            numneg = 0;          numfalse = 0;             /*确定强分类器的虚警率falsealarm          *遍历sampleIdx中所有样本,但是只需计算每个负样本的弱分类器置信度和,具体来说,          *也就是,对于每个负样本,遍历上面求出的所有求出count(在这里是seq->total)个          *弱分类器的置信度并求和,然后和上面所求强分类器的阈值相比较,即:          *if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )                  {                      numfalse++;                  }            *这样就可以计算出被分类错误的负样本的数量了          */        for( i = 0; i < numsamples; i++ )          {              idx = icvGetIdxAt( sampleIdx, i );                // 如果样本为负样本              if( data->cls.data.fl[idx] == 0.0F )              {                  numneg++;                  sum_stage = 0.0F;                    // 遍历seq中所有弱分类器                  for( j = 0; j < seq->total; j++ )                  {                     classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));                       // 累积当前负样本的分类器输出结果                     sum_stage += classifier->eval( (CvIntHaarClassifier*) classifier,                          (sum_type*) (data->sum.data.ptr + idx * data->sum.step),                          (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),                          data->normfactor.data.fl[idx] );                  }                  /* sum_stage = 2.0F * sum_stage - seq->total; */                    // 因为小于threshold为负类,所以下面是分类错误的情况                  if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )                  {                      numfalse++;                  }              }          }            // 计算虚警率          falsealarm = ((float) numfalse) / ((float) numneg);                       ...............}



1 0