cascadeclassifier.cpp

来源:互联网 发布:java大量数据处理调优 编辑:程序博客网 时间:2024/06/03 21:48
bool CvCascadeClassifier::train( const string _cascadeDirName,                                const string _posFilename,                                const string _negFilename,                                int _numPos, int _numNeg,                                int _precalcValBufSize, int _precalcIdxBufSize,                                int _numStages,                                const CvCascadeParams& _cascadeParams,                                const CvFeatureParams& _featureParams,                                const CvCascadeBoostParams& _stageParams,                                bool baseFormatSave ){    // Start recording clock ticks for training time output    const clock_t begin_time = clock();    if( _cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty() )        CV_Error( CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL" );    string dirName;    if (_cascadeDirName.find_last_of("/\\") == (_cascadeDirName.length() - 1) )        dirName = _cascadeDirName;    else        dirName = _cascadeDirName + '/';    numPos = _numPos;//正样本数目    numNeg = _numNeg;//负样本数目    numStages = _numStages;    if ( !imgReader.create( _posFilename, _negFilename, _cascadeParams.winSize ) )//为读取正负样本做准备    {        cout << "Image reader can not be created from -vec " << _posFilename                << " and -bg " << _negFilename << "." << endl;        return false;    }    if ( !load( dirName ) )    {        cascadeParams = _cascadeParams;        featureParams = CvFeatureParams::create(cascadeParams.featureType);        featureParams->init(_featureParams);        stageParams = new CvCascadeBoostParams;        *stageParams = _stageParams;        featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);//为提取特征做准备        featureEvaluator->init( (CvFeatureParams*)featureParams, numPos + numNeg, cascadeParams.winSize );        stageClassifiers.reserve( numStages );    }    cout << "PARAMETERS:" << endl;    cout << "cascadeDirName: " << _cascadeDirName << endl;    cout << "vecFileName: " << _posFilename << endl;    cout << "bgFileName: " << _negFilename << endl;    cout << "numPos: " << _numPos << endl;    cout << "numNeg: " << _numNeg << endl;    cout << "numStages: " << numStages << endl;    cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;    cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;    cascadeParams.printAttrs();    stageParams->printAttrs();    featureParams->printAttrs();    int startNumStages = (int)stageClassifiers.size();    if ( startNumStages > 1 )        cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;    else if ( startNumStages == 1)        cout << endl << "Stage 0 is loaded" << endl;    //比如默认为pow(0.5,numStages)/1。需要达到的虚警率,即FPR,也即代码中提到的Acceptance Rate    double requiredLeafFARate = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) /                                (double)stageParams->max_depth;    double tempLeafFARate;    for( int i = startNumStages; i < numStages; i++ )//for each stage    {        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;        cout << "<BEGIN" << endl;        if ( !updateTrainingSet( tempLeafFARate ) )//读取正负样本,并计算当前分类器的FPR,如果达到要求,则不用继续训练了        {            cout << "Train dataset for temp stage can not be filled. "                "Branch training terminated." << endl;            break;        }        if( tempLeafFARate <= requiredLeafFARate )//FPR达到要求,不用继续训练了        {            cout << "Required leaf false alarm rate achieved. "                 "Branch training terminated." << endl;            break;        }        CvCascadeBoost* tempStage = new CvCascadeBoost;//强分类器        bool isStageTrained = tempStage->train( (CvFeatureEvaluator*)featureEvaluator,                                                curNumSamples, _precalcValBufSize, _precalcIdxBufSize,                                                *((CvCascadeBoostParams*)stageParams) );//开始训练一个强分类器        cout << "END>" << endl;        if(!isStageTrained)//强分类器训练不成功,则结束训练            break;        stageClassifiers.push_back( tempStage );//保存强分类器        // save params        if( i == 0)        {            std::string paramsFilename = dirName + CC_PARAMS_FILENAME;            FileStorage fs( paramsFilename, FileStorage::WRITE);            if ( !fs.isOpened() )            {                cout << "Parameters can not be written, because file " << paramsFilename                        << " can not be opened." << endl;                return false;            }            fs << FileStorage::getDefaultObjectName(paramsFilename) << "{";            writeParams( fs );            fs << "}";        }        // save current stage        char buf[10];        sprintf(buf, "%s%d", "stage", i );        string stageFilename = dirName + buf + ".xml";        FileStorage fs( stageFilename, FileStorage::WRITE );        if ( !fs.isOpened() )        {            cout << "Current stage can not be written, because file " << stageFilename                    << " can not be opened." << endl;            return false;        }        fs << FileStorage::getDefaultObjectName(stageFilename) << "{";        tempStage->write( fs, Mat() );        fs << "}";        // Output training time up till now        float seconds = float( clock () - begin_time ) / CLOCKS_PER_SEC;        int days = int(seconds) / 60 / 60 / 24;        int hours = (int(seconds) / 60 / 60) % 24;        int minutes = (int(seconds) / 60) % 60;        int seconds_left = int(seconds) % 60;        cout << "Training until now has taken " << days << " days " << hours << " hours " << minutes << " minutes " << seconds_left <<" seconds." << endl;    }    if(stageClassifiers.size() == 0)    {        cout << "Cascade classifier can't be trained. Check the used training parameters." << endl;        return false;    }    save( dirName + CC_CASCADE_FILENAME, baseFormatSave );    return true;}
bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio){    int64 posConsumed = 0, negConsumed = 0;    imgReader.restart();    int posCount = fillPassedSamples( 0, numPos, true, posConsumed );//装载正样本    if( !posCount )//如果没有读到正样本,失败        return false;    cout << "POS count : consumed   " << posCount << " : " << (int)posConsumed << endl;//输出读取到的正样本,消耗掉的正样本数量。    //先求读到的正样本占总正样本数量的比例,然后乘以负样本总量,就是应该读取的负样本数量。    int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible    int negCount = fillPassedSamples( posCount, proNumNeg, false, negConsumed );//装载负样本    if ( !negCount )        return false;    curNumSamples = posCount + negCount;//读取到的样本总量    acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );//FP/PF+TN。被分类的负样本中,被错分为正样本的负样本比例。    cout << "NEG count : acceptanceRatio    " << negCount << " : " << acceptanceRatio << endl;    return true;}
int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, int64& consumed ){    int getcount = 0;    Mat img(cascadeParams.winSize, CV_8UC1);//建立一个mat,保存图片,注意此矩阵的行数是高,列数是宽。    //i为样本的索引,如果已经读取10个正样本,现在要读取负样本了,那么负样本的序号要从11开始。    for( int i = first; i < first + count; i++ )//开始读指数数目的图片    {        for( ; ; )        {            //让我读正样本还是负样本?,正,读一个正样本;负,读一个负样本            bool isGetImg = isPositive ? imgReader.getPos( img ) :                                           imgReader.getNeg( img );             if( !isGetImg )//图片读取失败,返回已经读到的图片数量                return getcount;            consumed++;//消耗的图片数量+1;之所以这样记录是因为可能读取多个图片才能找到一个被当前级联分类器认为是正样本的图片。对于读取负样本,也是这个道理。            //读取到的图片保存到featureEvalutor里。            featureEvaluator->setImage( img, isPositive ? 1 : 0, i );            //预测读取到的图片,如果为正样本,则处理。此处对于读取正样本好理解。当读取负样本的时候为什么也要执行这段代码?这是因为,代码要去找被当前分类器错认为是正阳本的负样本,收集这样的负样本有利于训练一个具有较强分辨能力的级联分类器。            if( predict( i ) == 1.0F )            {                getcount++;                printf("%s current samples: %d\r", isPositive ? "POS":"NEG", getcount);                break;            }        }    }    return getcount;//返回找到的正样本或负样本数量。}
int CvCascadeClassifier::predict( int sampleIdx )/输入样本的索引{    CV_DbgAssert( sampleIdx < numPos + numNeg );//检查索引是否在指定范围内    for (vector< Ptr<CvCascadeBoost> >::iterator it = stageClassifiers.begin();        it != stageClassifiers.end(); it++ )    {        //对于当前的每一级分类器,只要有一个分类器认为此样本是负样本,那么就返回0(即认为此样本是负样本);               if ( (*it)->predict( sampleIdx ) == 0.f )            return 0;    }     //如果所有级的分类器都认为此样本是正样本,则返回1(认为此样本是正样本)    return 1;}
0 0
原创粉丝点击