cvBoostStartTraining、cvBoostNextWeakClassifier 和 cvBoostEndTraining

来源:互联网 发布:js 数组元素能是函数么 编辑:程序博客网 时间:2024/04/29 16:18




/****************************************************************************************\*                                        Boosting                                        *\****************************************************************************************/typedef struct CvBoostTrainer{    CvBoostType type;      //一共四类如下   /* CV_DABCLASS = 0, // 2 class Discrete AdaBoost              CV_RABCLASS = 1, // 2 class Real AdaBoost                   CV_LBCLASS  = 2, // 2 class LogitBoost                      CV_GABCLASS = 3, //2 class Gentle AdaBoost             */   int count;             /* (idx) ? number_of_indices : number_of_samples */    int* idx;    float* F;} CvBoostTrainer;/* * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining * * These functions perform training of 2-class boosting classifier * using ANY appropriate weak classifier */staticCvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,     //训练样本的类别为0,1                                       CvMat* weakTrainVals,    //训练的弱分类器的输出值,-1和1                                       CvMat* /*weights*/,      //样本权重向量                                       CvMat* sampleIdx,        //正负样本索引                                       CvBoostType type )       //类型如上{    uchar* ydata;    int ystep;    int m;    uchar* traindata;    int trainstep;    int trainnum;    int i;    int idx;    size_t datasize;    CvBoostTrainer* ptr;                                      //该函数中这个最为重要    int idxnum;    int idxstep;    uchar* idxdata;    assert( trainClasses != NULL );    assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );    assert( weakTrainVals != NULL );    assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );    CV_MAT2VEC( *trainClasses, ydata, ystep, m );    CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );    CV_Assert( m == trainnum );    idxnum = 0;    idxstep = 0;    idxdata = NULL;    if( sampleIdx )    {        CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );    }  
/*******************************ptr的初始化*********************************************/   datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;    ptr = (CvBoostTrainer*) cvAlloc( datasize );         //为ptr分配内存    memset( ptr, 0, datasize );                          //初始化ptr,全部为0    ptr->F = NULL;    ptr->idx = NULL;    ptr->count = m;    ptr->type = type;    if( idxnum > 0 )    {        CvScalar s;       //s内部是四个double型的val,分别为val[0],val[1],val[2]val[3]        ptr->idx = (int*) (ptr + 1);        ptr->count = idxnum;        for( i = 0; i < ptr->count; i++ )        {           //将原始数据转化为cvScale类型的数据                       cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );            ptr->idx[i] = (int) s.val[0];        }    }    for( i = 0; i < ptr->count; i++ )    {        idx = (ptr->idx) ? ptr->idx[i] : i;        *((float*) (traindata + idx * trainstep)) =            2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;////y*=2y-1,类别标签由{0,1}变为{-1,1}    }    return ptr;}/* * * Discrete AdaBoost functions *根据训练出来的结果与标签进行比较,更新全部样本权重 */staticfloat icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,                                     CvMat* trainClasses,                                     CvMat* /*weakTrainVals*/,                                     CvMat* weights,                                     CvBoostTrainer* trainer ){    uchar* evaldata;    int evalstep;    int m;    uchar* ydata;    int ystep;    int ynum;    uchar* wdata;    int wstep;    int wnum;    float sumw;    float err;    int i;    int idx;    CV_Assert( weakEvalVals != NULL );    CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );    CV_Assert( trainClasses != NULL );    CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );    CV_Assert( weights != NULL );    CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );    CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );    CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );    CV_MAT2VEC( *weights, wdata, wstep, 
wnum );    CV_Assert( m == ynum );    CV_Assert( m == wnum );    sumw = 0.0F;    err = 0.0F;    for( i = 0; i < trainer->count; i++ )    {        idx = (trainer->idx) ? trainer->idx[i] : i;        sumw += *((float*) (wdata + idx*wstep));                   //所有训练样本权重和        err += (*((float*) (wdata + idx*wstep))) *            ( (*((float*) (evaldata + idx*evalstep))) !=                2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );  //训练错误样本的权重和    }    err /= sumw;                                                    //错误率比值    err = -cvLogRatio( err );                                       //取对数后,再取相反数,目的是把把err变成正值    for( i = 0; i < trainer->count; i++ )    {        idx = (trainer->idx) ? trainer->idx[i] : i;        *((float*) (wdata + idx*wstep)) *= expf( err *            ((*((float*) (evaldata + idx*evalstep))) !=                2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );//根据训练的结果正确与否,用指数函数更新权重。        sumw += *((float*) (wdata + idx*wstep));                   //更新权重后再重新计算全部样本的权重和    }    for( i = 0; i < trainer->count; i++ )    {        idx = (trainer->idx) ? 
trainer->idx[i] : i;        *((float*) (wdata + idx * wstep)) /= sumw;                 //把更新后的训练样本权重归一化    }    return err;                                                    //返回err,注意这个err是取对数后,再取相反数的那个err,也就是上文程序中最后那个err}typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,                                                 CvMat* weakTrainVals,                                                 CvMat* weights,                                                 CvMat* sampleIdx,                                                 CvBoostType type );typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,                                            CvMat* trainClasses,                                            CvMat* weakTrainVals,                                            CvMat* weights,                                            CvBoostTrainer* data );CvBoostStartTraining startTraining[4] = {        icvBoostStartTraining,        icvBoostStartTraining,        icvBoostStartTrainingLB,        icvBoostStartTraining    };CvBoostNextWeakClassifier nextWeakClassifier[4] = {        icvBoostNextWeakClassifierDAB,        icvBoostNextWeakClassifierRAB,        icvBoostNextWeakClassifierLB,        icvBoostNextWeakClassifierGAB    };/* * * Dispatchers * */CV_BOOST_IMPLCvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,                                      CvMat* weakTrainVals,                                      CvMat* weights,                                      CvMat* sampleIdx,                                      CvBoostType type ){    return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );}CV_BOOST_IMPLvoid cvBoostEndTraining( CvBoostTrainer** trainer ){    cvFree( trainer );    *trainer = NULL;}CV_BOOST_IMPLfloat cvBoostNextWeakClassifier( CvMat* weakEvalVals,                                 CvMat* trainClasses,                                 CvMat* weakTrainVals,                                 CvMat* weights,         
                        CvBoostTrainer* trainer ){    return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,        weakTrainVals, weights, trainer    );}


0 0
原创粉丝点击