【图像特征提取19】RANSAC算法原理与源码解析

来源:互联网 发布:js数据类型判断 编辑:程序博客网 时间:2024/04/29 11:53

本文转载自:http://blog.csdn.NET/luoshixian099/article/details/50217655


      随机抽样一致性(RANSAC)算法,可以在一组包含“外点”的数据集中,采用不断迭代的方法,寻找最优参数模型,不符合最优模型的点,被定义为“外点”。在图像配准以及拼接上得到广泛的应用,本文将对RANSAC算法在OpenCV中角点误匹配对的检测中进行解析。



1.RANSAC原理  

  OpenCV中滤除误匹配对采用RANSAC算法寻找一个最佳单应性矩阵H,矩阵大小为3×3。RANSAC目的是找到最优的参数矩阵使得满足该矩阵的数据点个数最多,通常令h33=1来归一化矩阵。由于单应性矩阵有8个未知参数,至少需要8个线性方程求解,对应到点位置信息上,一组点对可以列出两个方程,则至少包含4组匹配点对。


                                                                      其中(x,y)表示目标图像角点位置,(x',y')为场景图像角点位置,s为尺度参数

  RANSAC算法从匹配数据集中随机抽出4个样本并保证这4个样本之间不共线,计算出单应性矩阵,然后利用这个模型测试所有数据,并计算满足这个模型数据点的个数与投影误差(即代价函数),若此模型为最优模型,则对应的代价函数最小。


-----------------------------------------------------------------------------------------------------------------

RANSAC算法步骤: 

          1. 随机从数据集中随机抽出4个样本数据 (此4个样本之间不能共线),计算出变换矩阵H,记为模型M;

          2. 计算数据集中所有数据与模型M的投影误差,若误差小于阈值,加入内点集 I ;

          3. 如果当前内点集 I 元素个数大于最优内点集 I_best , 则更新 I_best = I,同时更新迭代次数k ;

          4. 如果迭代次数大于k,则退出 ; 否则迭代次数加1,并重复上述步骤;

  注:迭代次数k在不大于最大迭代次数的情况下,是在不断更新而不是固定的;

                                       其中,p为置信度,一般取0.995;w为"内点"的比例 ; m为计算模型所需要的最少样本数=4;

-----------------------------------------------------------------------------------------------------------------


2.例程

OpenCV中此功能通过调用findHomography函数调用,下面是个例程:

[cpp] view plain copy
 print?
  1. #include <iostream>  
  2. #include "opencv2/opencv.hpp"  
  3. #include "opencv2/core/core.hpp"  
  4. #include "opencv2/features2d/features2d.hpp"  
  5. #include "opencv2/highgui/highgui.hpp"  
  6. using namespace cv;  
  7. using namespace std;  
  8. int main(int argc, char** argv)  
  9. {  
  10.     Mat obj=imread("F:\\Picture\\obj.jpg");   //载入目标图像  
  11.     Mat scene=imread("F:\\Picture\\scene.jpg"); //载入场景图像  
  12.     if (obj.empty() || scene.empty() )  
  13.     {  
  14.         cout<<"Can't open the picture!\n";  
  15.         return 0;  
  16.     }  
  17.     vector<KeyPoint> obj_keypoints,scene_keypoints;  
  18.     Mat obj_descriptors,scene_descriptors;  
  19.     ORB detector;     //采用ORB算法提取特征点  
  20.     detector.detect(obj,obj_keypoints);  
  21.     detector.detect(scene,scene_keypoints);  
  22.     detector.compute(obj,obj_keypoints,obj_descriptors);  
  23.     detector.compute(scene,scene_keypoints,scene_descriptors);  
  24.     BFMatcher matcher(NORM_HAMMING,true); //汉明距离做为相似度度量  
  25.     vector<DMatch> matches;  
  26.     matcher.match(obj_descriptors, scene_descriptors, matches);  
  27.     Mat match_img;  
  28.     drawMatches(obj,obj_keypoints,scene,scene_keypoints,matches,match_img);  
  29.     imshow("滤除误匹配前",match_img);  
  30.   
  31.     //保存匹配对序号  
  32.     vector<int> queryIdxs( matches.size() ), trainIdxs( matches.size() );  
  33.     forsize_t i = 0; i < matches.size(); i++ )  
  34.     {  
  35.         queryIdxs[i] = matches[i].queryIdx;  
  36.         trainIdxs[i] = matches[i].trainIdx;  
  37.     }     
  38.   
  39.     Mat H12;   //变换矩阵  
  40.   
  41.     vector<Point2f> points1; KeyPoint::convert(obj_keypoints, points1, queryIdxs);  
  42.     vector<Point2f> points2; KeyPoint::convert(scene_keypoints, points2, trainIdxs);  
  43.     int ransacReprojThreshold = 5;  //拒绝阈值  
  44.   
  45.   
  46.     H12 = findHomography( Mat(points1), Mat(points2), CV_RANSAC, ransacReprojThreshold );  
  47.     vector<char> matchesMask( matches.size(), 0 );    
  48.     Mat points1t;  
  49.     perspectiveTransform(Mat(points1), points1t, H12);  
  50.     forsize_t i1 = 0; i1 < points1.size(); i1++ )  //保存‘内点’  
  51.     {  
  52.         if( norm(points2[i1] - points1t.at<Point2f>((int)i1,0)) <= ransacReprojThreshold ) //给内点做标记  
  53.         {  
  54.             matchesMask[i1] = 1;  
  55.         }     
  56.     }  
  57.     Mat match_img2;   //滤除‘外点’后  
  58.     drawMatches(obj,obj_keypoints,scene,scene_keypoints,matches,match_img2,Scalar(0,0,255),Scalar::all(-1),matchesMask);  
  59.   
  60.     //画出目标位置  
  61.     std::vector<Point2f> obj_corners(4);  
  62.     obj_corners[0] = cvPoint(0,0); obj_corners[1] = cvPoint( obj.cols, 0 );  
  63.     obj_corners[2] = cvPoint( obj.cols, obj.rows ); obj_corners[3] = cvPoint( 0, obj.rows );  
  64.     std::vector<Point2f> scene_corners(4);  
  65.     perspectiveTransform( obj_corners, scene_corners, H12);  
  66.     line( match_img2, scene_corners[0] + Point2f(static_cast<float>(obj.cols), 0),   
  67.         scene_corners[1] + Point2f(static_cast<float>(obj.cols), 0),Scalar(0,0,255),2);  
  68.     line( match_img2, scene_corners[1] + Point2f(static_cast<float>(obj.cols), 0),   
  69.         scene_corners[2] + Point2f(static_cast<float>(obj.cols), 0),Scalar(0,0,255),2);  
  70.     line( match_img2, scene_corners[2] + Point2f(static_cast<float>(obj.cols), 0),   
  71.         scene_corners[3] + Point2f(static_cast<float>(obj.cols), 0),Scalar(0,0,255),2);  
  72.     line( match_img2, scene_corners[3] + Point2f(static_cast<float>(obj.cols), 0),  
  73.         scene_corners[0] + Point2f(static_cast<float>(obj.cols), 0),Scalar(0,0,255),2);  
  74.   
  75.     imshow("滤除误匹配后",match_img2);  
  76.     waitKey(0);  
  77.       
  78.     return 0;  
  79. }  

3. RANSAC源码解析

(1)findHomography内部调用cvFindHomography函数

srcPoints:目标图像点集

dstPoints:场景图像点集

method: 最小中值法、RANSAC方法、最小二乘法

ransacReprojTheshold:投影误差阈值

mask:掩码

[cpp] view plain copy
 print?
  1. cvFindHomography( const CvMat* objectPoints, const CvMat* imagePoints,  
  2.                   CvMat* __H, int method, double ransacReprojThreshold,  
  3.                   CvMat* mask )  
  4. {  
  5.     const double confidence = 0.995;  //置信度  
  6.     const int maxIters = 2000;    //迭代最大次数  
  7.     const double defaultRANSACReprojThreshold = 3; //默认拒绝阈值  
  8.     bool result = false;  
  9.     Ptr<CvMat> m, M, tempMask;  
  10.   
  11.     double H[9];  
  12.     CvMat matH = cvMat( 3, 3, CV_64FC1, H );  //变换矩阵  
  13.     int count;  
  14.   
  15.     CV_Assert( CV_IS_MAT(imagePoints) && CV_IS_MAT(objectPoints) );  
  16.   
  17.     count = MAX(imagePoints->cols, imagePoints->rows);  
  18.     CV_Assert( count >= 4 );           //至少有4个样本  
  19.     if( ransacReprojThreshold <= 0 )  
  20.         ransacReprojThreshold = defaultRANSACReprojThreshold;  
  21.   
  22.     m = cvCreateMat( 1, count, CV_64FC2 );  //转换为齐次坐标  
  23.     cvConvertPointsHomogeneous( imagePoints, m );  
  24.   
  25.     M = cvCreateMat( 1, count, CV_64FC2 );//转换为齐次坐标  
  26.     cvConvertPointsHomogeneous( objectPoints, M );  
  27.   
  28.     if( mask )  
  29.     {  
  30.         CV_Assert( CV_IS_MASK_ARR(mask) && CV_IS_MAT_CONT(mask->type) &&  
  31.             (mask->rows == 1 || mask->cols == 1) &&  
  32.             mask->rows*mask->cols == count );  
  33.     }  
  34.     if( mask || count > 4 )  
  35.         tempMask = cvCreateMat( 1, count, CV_8U );  
  36.     if( !tempMask.empty() )  
  37.         cvSet( tempMask, cvScalarAll(1.) );  
  38.   
  39.     CvHomographyEstimator estimator(4);  
  40.     if( count == 4 )  
  41.         method = 0;  
  42.     if( method == CV_LMEDS )  //最小中值法  
  43.         result = estimator.runLMeDS( M, m, &matH, tempMask, confidence, maxIters );  
  44.     else if( method == CV_RANSAC )    //采用RANSAC算法  
  45.         result = estimator.runRANSAC( M, m, &matH, tempMask, ransacReprojThreshold, confidence, maxIters);//(2)解析  
  46.     else  
  47.         result = estimator.runKernel( M, m, &matH ) > 0; //最小二乘法  
  48.   
  49.     if( result && count > 4 )  
  50.     {  
  51.         icvCompressPoints( (CvPoint2D64f*)M->data.ptr, tempMask->data.ptr, 1, count );  //保存内点集  
  52.         count = icvCompressPoints( (CvPoint2D64f*)m->data.ptr, tempMask->data.ptr, 1, count );  
  53.         M->cols = m->cols = count;  
  54.         if( method == CV_RANSAC )  //  
  55.             estimator.runKernel( M, m, &matH );  //内点集上采用最小二乘法重新估算变换矩阵  
  56.         estimator.refine( M, m, &matH, 10 );//   
  57.     }  
  58.   
  59.     if( result )  
  60.         cvConvert( &matH, __H );  //保存变换矩阵  
  61.   
  62.     if( mask && tempMask )  
  63.     {  
  64.         if( CV_ARE_SIZES_EQ(mask, tempMask) )  
  65.            cvCopy( tempMask, mask );  
  66.         else  
  67.            cvTranspose( tempMask, mask );  
  68.     }  
  69.   
  70.     return (int)result;  
  71. }  

(2) runRANSAC
maxIters:最大迭代次数

m1,m2 :数据样本

model:变换矩阵

reprojThreshold:投影误差阈值

confidence:置信度  0.995

[cpp] view plain copy
 print?
  1. bool CvModelEstimator2::runRANSAC( const CvMat* m1, const CvMat* m2, CvMat* model,  
  2.                                     CvMat* mask0, double reprojThreshold,  
  3.                                     double confidence, int maxIters )  
  4. {  
  5.     bool result = false;  
  6.     cv::Ptr<CvMat> mask = cvCloneMat(mask0);  
  7.     cv::Ptr<CvMat> models, err, tmask;  
  8.     cv::Ptr<CvMat> ms1, ms2;  
  9.   
  10.     int iter, niters = maxIters;  
  11.     int count = m1->rows*m1->cols, maxGoodCount = 0;  
  12.     CV_Assert( CV_ARE_SIZES_EQ(m1, m2) && CV_ARE_SIZES_EQ(m1, mask) );  
  13.   
  14.     if( count < modelPoints )  
  15.         return false;  
  16.   
  17.     models = cvCreateMat( modelSize.height*maxBasicSolutions, modelSize.width, CV_64FC1 );  
  18.     err = cvCreateMat( 1, count, CV_32FC1 );//保存每组点对应的投影误差  
  19.     tmask = cvCreateMat( 1, count, CV_8UC1 );  
  20.   
  21.     if( count > modelPoints )  //多于4个点  
  22.     {  
  23.         ms1 = cvCreateMat( 1, modelPoints, m1->type );  
  24.         ms2 = cvCreateMat( 1, modelPoints, m2->type );  
  25.     }  
  26.     else  //等于4个点  
  27.     {  
  28.         niters = 1; //迭代一次  
  29.         ms1 = cvCloneMat(m1);  //保存每次随机找到的样本点  
  30.         ms2 = cvCloneMat(m2);  
  31.     }  
  32.   
  33.     for( iter = 0; iter < niters; iter++ )  //不断迭代  
  34.     {  
  35.         int i, goodCount, nmodels;  
  36.         if( count > modelPoints )  
  37.         {  
  38.            //在(3)解析getSubset  
  39.             bool found = getSubset( m1, m2, ms1, ms2, 300 ); //随机选择4组点,且三点之间不共线,(3)解析  
  40.             if( !found )  
  41.             {  
  42.                 if( iter == 0 )  
  43.                     return false;  
  44.                 break;  
  45.             }  
  46.         }  
  47.   
  48.         nmodels = runKernel( ms1, ms2, models );  //估算近似变换矩阵,返回1  
  49.         if( nmodels <= 0 )  
  50.             continue;  
  51.         for( i = 0; i < nmodels; i++ )//执行一次   
  52.         {  
  53.             CvMat model_i;  
  54.             cvGetRows( models, &model_i, i*modelSize.height, (i+1)*modelSize.height );//获取3×3矩阵元素  
  55.             goodCount = findInliers( m1, m2, &model_i, err, tmask, reprojThreshold );  //找出内点,(4)解析  
  56.   
  57.             if( goodCount > MAX(maxGoodCount, modelPoints-1) )  //当前内点集元素个数大于最优内点集元素个数  
  58.             {  
  59.                 std::swap(tmask, mask);  
  60.                 cvCopy( &model_i, model );  //更新最优模型  
  61.                 maxGoodCount = goodCount;  
  62.                 //confidence 为置信度默认0.995,modelPoints为最少所需样本数=4,niters迭代次数  
  63.                 niters = cvRANSACUpdateNumIters( confidence,  //更新迭代次数,(5)解析  
  64.                     (double)(count - goodCount)/count, modelPoints, niters );  
  65.             }  
  66.         }  
  67.     }  
  68.   
  69.     if( maxGoodCount > 0 )  
  70.     {  
  71.         if( mask != mask0 )  
  72.             cvCopy( mask, mask0 );  
  73.         result = true;  
  74.     }  
  75.   
  76.     return result;  
  77. }  

(3)getSubset

ms1,ms2:保存随机找到的4个样本

maxAttempts:最大迭代次数,300

[cpp] view plain copy
 print?
  1. bool CvModelEstimator2::getSubset( const CvMat* m1, const CvMat* m2,  
  2.                                    CvMat* ms1, CvMat* ms2, int maxAttempts )  
  3. {  
  4.     cv::AutoBuffer<int> _idx(modelPoints); //modelPoints 所需要最少的样本点个数  
  5.     int* idx = _idx;  
  6.     int i = 0, j, k, idx_i, iters = 0;  
  7.     int type = CV_MAT_TYPE(m1->type), elemSize = CV_ELEM_SIZE(type);  
  8.     const int *m1ptr = m1->data.i, *m2ptr = m2->data.i;  
  9.     int *ms1ptr = ms1->data.i, *ms2ptr = ms2->data.i;  
  10.     int count = m1->cols*m1->rows;  
  11.   
  12.     assert( CV_IS_MAT_CONT(m1->type & m2->type) && (elemSize % sizeof(int) == 0) );  
  13.     elemSize /= sizeof(int); //每个数据占用字节数  
  14.   
  15.     for(; iters < maxAttempts; iters++)  
  16.     {  
  17.         for( i = 0; i < modelPoints && iters < maxAttempts; )  
  18.         {  
  19.             idx[i] = idx_i = cvRandInt(&rng) % count;  // 随机选取1组点  
  20.             for( j = 0; j < i; j++ )  // 检测是否重复选择  
  21.                 if( idx_i == idx[j] )  
  22.                     break;  
  23.             if( j < i )  
  24.                 continue;  //重新选择  
  25.             for( k = 0; k < elemSize; k++ )  //拷贝点数据  
  26.             {  
  27.                 ms1ptr[i*elemSize + k] = m1ptr[idx_i*elemSize + k];  
  28.                 ms2ptr[i*elemSize + k] = m2ptr[idx_i*elemSize + k];  
  29.             }  
  30.             if( checkPartialSubsets && (!checkSubset( ms1, i+1 ) || !checkSubset( ms2, i+1 )))//检测点之间是否共线  
  31.             {  
  32.                 iters++;  //若共线,重新选择一组  
  33.                 continue;  
  34.             }  
  35.             i++;  
  36.         }  
  37.         if( !checkPartialSubsets && i == modelPoints &&  
  38.             (!checkSubset( ms1, i ) || !checkSubset( ms2, i )))  
  39.             continue;  
  40.         break;  
  41.     }  
  42.   
  43.     return i == modelPoints && iters < maxAttempts;  //返回找到的样本点个数  
  44. }  

(4) findInliers & computerReprojError

[cpp] view plain copy
 print?
  1. int CvModelEstimator2::findInliers( const CvMat* m1, const CvMat* m2,  
  2.                                     const CvMat* model, CvMat* _err,  
  3.                                     CvMat* _mask, double threshold )  
  4. {  
  5.     int i, count = _err->rows*_err->cols, goodCount = 0;  
  6.     const float* err = _err->data.fl;  
  7.     uchar* mask = _mask->data.ptr;  
  8.   
  9.     computeReprojError( m1, m2, model, _err );  //计算每组点的投影误差  
  10.     threshold *= threshold;  
  11.     for( i = 0; i < count; i++ )  
  12.         goodCount += mask[i] = err[i] <= threshold;//误差在限定范围内,加入‘内点集’  
  13.     return goodCount;  
  14. }  
  15. //--------------------------------------  
  16. void CvHomographyEstimator::computeReprojError( const CvMat* m1, const CvMat* m2,const CvMat* model, CvMat* _err )  
  17. {  
  18.     int i, count = m1->rows*m1->cols;  
  19.     const CvPoint2D64f* M = (const CvPoint2D64f*)m1->data.ptr;  
  20.     const CvPoint2D64f* m = (const CvPoint2D64f*)m2->data.ptr;  
  21.     const double* H = model->data.db;  
  22.     float* err = _err->data.fl;  
  23.   
  24.     for( i = 0; i < count; i++ )        //保存每组点的投影误差,对应上述变换公式  
  25.     {  
  26.         double ww = 1./(H[6]*M[i].x + H[7]*M[i].y + 1.);      
  27.         double dx = (H[0]*M[i].x + H[1]*M[i].y + H[2])*ww - m[i].x;  
  28.         double dy = (H[3]*M[i].x + H[4]*M[i].y + H[5])*ww - m[i].y;  
  29.         err[i] = (float)(dx*dx + dy*dy);  
  30.     }  
  31. }  
(5)cvRANSACUpdateNumIters

对应上述k的计算公式

p:置信度

ep:外点比例

[cpp] view plain copy
 print?
  1. cvRANSACUpdateNumIters( double p, double ep,  
  2.                         int model_points, int max_iters )  
  3. {  
  4.     if( model_points <= 0 )  
  5.         CV_Error( CV_StsOutOfRange, "the number of model points should be positive" );  
  6.   
  7.     p = MAX(p, 0.);  
  8.     p = MIN(p, 1.);  
  9.     ep = MAX(ep, 0.);  
  10.     ep = MIN(ep, 1.);  
  11.   
  12.     // avoid inf's & nan's  
  13.     double num = MAX(1. - p, DBL_MIN);  //num=1-p,做分子  
  14.     double denom = 1. - pow(1. - ep,model_points);//做分母  
  15.     if( denom < DBL_MIN )  
  16.         return 0;  
  17.   
  18.     num = log(num);  
  19.     denom = log(denom);  
  20.   
  21.     return denom >= 0 || -num >= max_iters*(-denom) ?  
  22.         max_iters : cvRound(num/denom);  
  23. }  
0 0