opencv中adaboost训练算法分析

来源:互联网 发布:怎么把链接复制到淘宝 编辑:程序博客网 时间:2024/05/12 22:56

0、概述  

opencv集成了经典adaboost算法,并结合haar特征实现了人脸检测功能。算法原理可参考人脸检测大牛Paul Viola 的文章《Rapid Object Detection using a Boosted Cascade of Simple Feature》。由于该算法堪称经典,并可推广应用于其他相关检测识别领域(如车牌检测、车辆检测识别),因此有必要从源码上学习其实现过程。人脸检测说是检测,实际上关键算法体现在训练(training)模块,openCV2.4.4 内含haartraining代码,本篇博文及后续相关博文都基于training算法为说明对象。

1 预备知识

1. 什么是haar特征?
haar特征是众多图像特征中的一种,计算方法为将特征窗口内的像素相加或相减。典型特征窗口如下图所示;
这里写图片描述
2. 一副图像上haar特征为什么有很多,具体有多少个?
由于haar特征窗口可有很多种类别,如上图中的(A、B、C、D不限于此四类)类,每一种类又可以变化特征窗口的尺度,在类别和尺度都确定的基础上,haar特征窗口还可以在样本图像上平移滑动。因此,一副图像上可生成众多haar特征。
具体个数请参考其他博文。haar特征及个数计算
3. 为什么要计算haar特征?
每一个haar特征可视为一个弱分类器(后面会解释),训练过程就是选择弱分类器(haar特征)的过程。
4. 什么是积分图,为什么用积分图?
积分图可加快haar特征的计算。
5. 什么是adaboost算法,如何实现?
集成学习算法包括两大类算法:bagging 算法和boosting 算法。
adaboost属于集成学习算法boosting的一种实现,在《Rapid Object Detection using a Boosted Cascade of Simple Features》作者论文中给出了adaboost算法流程图:
这里写图片描述

需要说明的是,该流程与opencv实现的方法有些出入,如openCV将负样本标记为-1,而文
献中标记为0,因此强分类器的阈值判断也不同。
openCV实现的adaboost,更接近于此片博文adaboost算法原理推导
这里写图片描述
6. 什么是决策树,为什么要用决策树?
决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。(盗用百度百科)
这里写图片描述
为啥要知道决策树? 因为opencV利用CART树,实现的弱分类器训练。

2 分类器训练

  为了表述的严谨性,本博文涉及的几个常用名词定义如下:
  (1)弱分类器:在本博文就是指CART,当CART只有一个分裂节点时,CART退化为Stump。CART的每个分裂节点都由1个haar特征。
  (2)强分类器:在本文下就是指Stage。

2.1 分类器训练概览

知道了haar特征,知道了adaboost,那么openCV到底是如何实现训练分类器的呢? 整体框架如下图所示。算法采用Cascade Tree结构,cascade Tree内部由多个stage构成,stage内部又由多个CART构成。opencv 实现分类器训练的过程就是建立cacade Tree的过程。
这里写图片描述

分类器最基本的训练单元为CART,然后是生成stage,最后建立cascade,以下详细展开分析。
源码调用关系图如下:
这里写图片描述

2.2 CART创建

  CART创建的关键是搜索并获得有效的haar特征。所涉及的函数及调用关系如下:
  

CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData, //训练样本                      int flags,   //行、列                      CvMat* trainClasses, //类别标识,正样本+1,负样本-1                      CvMat* /*typeMask*/,                      CvMat* missedMeasurementsMask,                      CvMat* compIdx,                      CvMat* sampleIdx,                      CvMat* weights,  // 样本权重                      CvClassifierTrainParams* trainParams //训练参数                         //stumperror分裂规则:"misclass", "gini", "entropy"                       )

cvCreateMTStumpClassifier 函数是创建CART的主函数,此过程是训练最耗时的一部分。此函数在内部调用findStumpThreshold_16s[stumperror],搜索弱分类器的阈值。

CvFindThresholdFunc findStumpThreshold_16s[4] = {        icvFindStumpThreshold_misc_16s,        icvFindStumpThreshold_gini_16s,        icvFindStumpThreshold_entropy_16s,        icvFindStumpThreshold_sq_16s    };

以上为函数指针,每个函数指针都是以宏实现。以熵衰减规则分裂节点函数(icvFindStumpThreshold_gini_16s)为例,定义如下

/* entropy error * err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg)) */#define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type )                             \    ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type,                                \        wposl = 0.5F * ( wl + wyl ); //左分支点中正样本权重和                                                    \        wposr = 0.5F * ( wr + wyr ); //右分支点中正样本权重和                                                    \        curleft = 0.5F * ( 1.0F + curleft ); //左分支中正样本权重和/左分支正负样本权重总和(左分支中正样本所占的比例,超过0.5则分类为正)                                           \        curright = 0.5F * ( 1.0F + curright ); // 右分支正样本权重和/右分支正负样本权重总和(右分支中正样本所占比例,超过0.5则分类为负)                                         \        curlerror = currerror = 0.0F;                                                    \        if( curleft > CV_ENTROPY_THRESHOLD )  // 左分支熵                                           \            curlerror -= wposl * logf( curleft );                                        \        if( curleft < 1.0F - CV_ENTROPY_THRESHOLD )                                     \            curlerror -= (wl - wposl) * logf( 1.0F - curleft );                          \                                                                                         \        if( curright > CV_ENTROPY_THRESHOLD )   //右分支熵                                          \            currerror -= wposr * logf( curright );                                       \        if( curright < 1.0F - CV_ENTROPY_THRESHOLD )                                     \            currerror -= (wr - wposr) * logf( 1.0F - curright );                         \    )

寻找弱分类器阈值函数

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \        uchar* data, size_t datastep,                                                    \        uchar* wdata, size_t wstep,                                                      \        uchar* ydata, size_t ystep,                                                      \        uchar* idxdata, size_t idxstep, int num,                                         \        float* lerror,                                                                   \        float* rerror,                                                                   \        float* threshold, float* left, float* right,                                     \        float* sumw, float* sumwy, float* sumwyy )                                       \{                                                                                        \    int found = 0;                                                                       \    float wyl  = 0.0F;  //左分支,各样本权重乘以类别y并求和                                                                 \    float wl   = 0.0F;   // 左分支各样本权重求和                                                               \    float wyyl = 0.0F;   // 左分支,各样本类别的平方乘以权重后求和                                                                \    float wyr  = 0.0F;   //右分支,各样本权重乘以类别y并求和                                                                \    float wr   = 0.0F;   // 右分支各样本权重求和                                                                  \                                                                                         \    float curleft  = 0.0F;   //左分支,正样本权重和/总的权重和                                                            \    float curright = 0.0F;   //右分支,正样本权重和/总的权重和                                                            \    float* prevval = NULL;                                                               \    float* curval  = NULL;                                                               \    float curlerror = 0.0F;    //左分支,的熵                                                          \    float currerror = 0.0F;    //右分支,的熵                                                           \    float wposl;   // 分配到左分支的正样本权重和                                                                     \    float wposr;    //分配到右分支的正样本权重和                                                                     \                                                                                         \    int i = 0;                                                                           \    int idx = 0;                                                                         \                                                                                         \    wposl = wposr = 0.0F;                                                                \    if( *sumw == FLT_MAX )                                                               \    {                                                                                    \        /* calculate sums */                                                             \        float *y = NULL;                                                                 \        float *w = NULL;                                                                 \        float wy = 0.0F;                                                                 \                                                                                         \        *sumw   = 0.0F;                                                                  \        *sumwy  = 0.0F;                                                                  \        *sumwyy = 0.0F;                                                                  \        for( i = 0; i < num; i++ )                                                       \        {                                                                                \            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \            w = (float*) (wdata + idx * wstep);                                          \            *sumw += *w;      //权重和                                                            \            y = (float*) (ydata + idx * ystep);                                          \            wy = (*w) * (*y);                                                            \            *sumwy += wy;   //类别权重和                                                             \            *sumwyy += wy * (*y);  //当y=+1或-1时,此值同 sumw                                                      \        }                                                                                \    }                                                                                    \                                                                                         \    for( i = 0; i < num; i++ )                                                           \    {                                                                                    \        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \        curval = (float*) (data + idx * datastep);                                       \         /* for debug purpose */                                                         \        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \                                                                                         \        wyr  = *sumwy - wyl;                                                             \        wr   = *sumw  - wl;                                                              \                                                                                         \        if( wl > 0.0 ) curleft = wyl / wl;                                               \        else curleft = 0.0F;                                                             \                                                                                         \        if( wr > 0.0 ) curright = wyr / wr;                                              \        else curright = 0.0F;                                                            \                                                                                         \        error                                                                            \                                                                                         \            if( curlerror + currerror < (*lerror) + (*rerror) )                              \         查找到使熵最小的阈值点           {                                                                                \            (*lerror) = curlerror;  //左分支熵                                                     \            (*rerror) = currerror;  //右分支熵                                                     \            *threshold = *curval;    //阈值                                                    \            if( i > 0 ) {                                                                \                *threshold = 0.5F * (*threshold + *prevval);                             \            }                                                                            \            *left  = curleft;  // 左分支中,正样本(权重)占左分支总权重的比例                                                          \            *right = curright; //右分支,正样本(权重)占右分支总权重的比例                                                          \            found = 1;                                                                   \        }                                                                                \                                                                                         \        do           //计算左右权重和                                                                    \        {                                                                                \            wl  += *((float*) (wdata + idx * wstep));                                    \            wyl += (*((float*) (wdata + idx * wstep)))                                   \                * (*((float*) (ydata + idx * ystep)));                                   \            wyyl += *((float*) (wdata + idx * wstep))                                    \                * (*((float*) (ydata + idx * ystep)))                                    \                * (*((float*) (ydata + idx * ystep)));                                   \        }                                                                                \        while( (++i) < num &&                                                            \            ( *((float*) (data + (idx =                                                  \                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \                == *curval ) );                                                          \        --i;                                                                             \        prevval = curval;                                                                \    } /* for each value */                                                               \                                                                                         \    return found;                                                                        \}

2.3 Stage 生成

某个CART弱分类器创建完成,然后根据adaboost权重更新规则,计算此弱分类器的αi,在DAB(Discrete AdaBoost)模式下,左右分支节点输出便是预测值与αi的乘积。
  Stage内弱分类器的个数,由maxfalsealarm确定。如果一直达不到小于maxfalsealarm的要求,则需要训练更多的弱分类器放入Stage中。
  Stage的 阈值由minhitrate确定。多个弱分类器的联合(即强分类器)对正样本的预测概率不能小于minhitrate。

2.4 几个问题说明

1,为什么训练过程中经常出现卡死,无法进入下一个stage的情况?
假设参数输入时候的负样本个数为1000; 则每一级stage都需要从待选的所有负样本中获得1000个被上一级错分为正样本的负样本(体现为虚警率)。由于越往后虚警率越低,同样是获取1000个错分样本,后面需要遍历的负样本范围越来越大。当最终无法得到错分样本时,程序便进入死循环中。

staticint icvGetHaarTrainingData( CvHaarTrainingData* data, int first, int count,                            CvIntHaarClassifier* cascade,                            CvGetHaarTrainingDataCallback callback, void* userdata,                            int* consumed, double* acceptance_ratio ){    int i = 0;    ccounter_t getcount = 0;    ccounter_t thread_getcount = 0;    ccounter_t consumed_count;     ccounter_t thread_consumed_count;    /* private variables */    CvMat img;    CvMat sum;    CvMat tilted;    CvMat sqsum;    sum_type* sumdata;    sum_type* tilteddata;    float*    normfactor;    /* end private variables */    assert( data != NULL );    assert( first + count <= data->maxnum );    assert( cascade != NULL );    assert( callback != NULL );    // if( !cvbgdata ) return 0; this check needs to be done in the callback for BG    CCOUNTER_SET_ZERO(getcount);    CCOUNTER_SET_ZERO(thread_getcount);    CCOUNTER_SET_ZERO(consumed_count);    CCOUNTER_SET_ZERO(thread_consumed_count);    #ifdef CV_OPENMP    #pragma omp parallel private(img, sum, tilted, sqsum, sumdata, tilteddata, \                                 normfactor, thread_consumed_count, thread_getcount)    #endif /* CV_OPENMP */    {        sumdata    = NULL;        tilteddata = NULL;        normfactor = NULL;        CCOUNTER_SET_ZERO(thread_getcount);        CCOUNTER_SET_ZERO(thread_consumed_count);        int ok = 1;        img = cvMat( data->winsize.height, data->winsize.width, CV_8UC1,            cvAlloc( sizeof( uchar ) * data->winsize.height * data->winsize.width ) );        sum = cvMat( data->winsize.height + 1, data->winsize.width + 1,                     CV_SUM_MAT_TYPE, NULL );        tilted = cvMat( data->winsize.height + 1, data->winsize.width + 1,                        CV_SUM_MAT_TYPE, NULL );        sqsum = cvMat( data->winsize.height + 1, data->winsize.width + 1, CV_SQSUM_MAT_TYPE,                       cvAlloc( sizeof( sqsum_type ) * (data->winsize.height + 1)                                                     * (data->winsize.width + 1) ) );        #ifdef CV_OPENMP        #pragma omp for schedule(static, 1)        #endif /* CV_OPENMP */        for( i = first; (i < first + count); i++ )        {            if( !ok )                continue;            for( ; ; ) //当没有合适的负样本时,陷入死循环            {                ok = callback( &img, userdata );                if( !ok )                    break;                CCOUNTER_INC(thread_consumed_count);                sumdata = (sum_type*) (data->sum.data.ptr + i * data->sum.step);                tilteddata = (sum_type*) (data->tilted.data.ptr + i * data->tilted.step);                normfactor = data->normfactor.data.fl + i;                sum.data.ptr = (uchar*) sumdata;                tilted.data.ptr = (uchar*) tilteddata;                icvGetAuxImages( &img, &sum, &tilted, &sqsum, normfactor );                            if( cascade->eval( cascade, sumdata, tilteddata, *normfactor ) != 0.0F )//读取正样本时,有可能小于正样本个数;                                                                                        //读取负样本时,反应的是错分为正样本的个数                {                    CCOUNTER_INC(thread_getcount);                    break;                }            }#ifdef CV_VERBOSE            if( (i - first) % 500 == 0 )            {                fprintf( stderr, "%3d%%\r", (int) ( 100.0 * (i - first) / count ) );                fflush( stderr );            }#endif /* CV_VERBOSE */        }        cvFree( &(img.data.ptr) );        cvFree( &(sqsum.data.ptr) );        #ifdef CV_OPENMP        #pragma omp critical (c_consumed_count)        #endif /* CV_OPENMP */        {            /* consumed_count += thread_consumed_count; */            CCOUNTER_ADD(getcount, thread_getcount);            CCOUNTER_ADD(consumed_count, thread_consumed_count);        }    } /* omp parallel */    if( consumed != NULL )    {        *consumed = (int)consumed_count;    }    if( acceptance_ratio != NULL )    {        /* *acceptance_ratio = ((double) count) / consumed_count; */        *acceptance_ratio = CCOUNTER_DIV(count, consumed_count); // 计算虚警率    }    return static_cast<int>(getcount);}

2,负样本是怎么截取的?
将输入图像等比例缩小到最小尺寸(不小于样本尺寸),然后再逐渐放大至原始尺寸。期间,用样本同样大小(width,height)的窗口滑动(步长为width/2,height/2 ),截取负样本。

0 0
原创粉丝点击