RANSAC(随机采样一致算法)原理及openCV代码实现

来源:互联网 发布:神武2白无忧优化 编辑:程序博客网 时间:2024/05/24 01:49


本文转自:http://blog.csdn.net/yihaizhiyan/article/details/5973729

http://blog.csdn.net/Sway_2012/article/details/37765765

http://blog.csdn.net/zouwen198317/article/details/38494149

1.什么是RANSAC?

RANSAC是RANdom SAmple Consensus(随机抽样一致性)的缩写。它是从一个观察数据集合中,估计模型参数(模型拟合)的迭代方法。它是一种随机的不确定算法,每次运算求出的结果可能不相同,但总能给出一个合理的结果,为了提高概率必须提高迭代次数。

2.算法详解


给定两个点p1与p2的坐标,确定这两点所构成的直线,要求对于输入的任意点p3,都可以判断它是否在该直线上。初中解析几何知识告诉我们,判断一个点在直线上,只需其与直线上任意两点点斜率都相同即可。实际操作当中,往往会先根据已知的两点算出直线的表达式(点斜式、截距式等等),然后通过向量计算即可方便地判断p3是否在该直线上。 

生产实践中的数据往往会有一定的偏差。例如我们知道两个变量X与Y之间呈线性关系,Y=aX+b,我们想确定参数a与b的具体值。通过实验,可以得到一组X与Y的测试值。虽然理论上两个未知数的方程只需要两组值即可确认,但由于系统误差的原因,任意取两点算出的a与b的值都不尽相同。我们希望的是,最后计算得出的理论模型与测试值的误差最小。大学的高等数学课程中,详细阐述了最小二乘法的思想。通过计算最小均方差关于参数a、b的偏导数为零时的值。事实上,在很多情况下,最小二乘法都是线性回归的代名词。 

遗憾的是,最小二乘法只适合与误差较小的情况。试想一下这种情况,假使需要从一个噪音较大的数据集中提取模型(比方说只有20%的数据时符合模型的)时,最小二乘法就显得力不从心了。例如下图,肉眼可以很轻易地看出一条直线(模式),但算法却找错了。 



RANSAC算法的输入是一组观测数据(往往含有较大的噪声或无效点),一个用于解释观测数据的参数化模型以及一些可信的参数。RANSAC通过反复选择数据中的一组随机子集来达成目标。被选取的子集被假设为局内点,并用下述方法进行验证: 

  • 有一个模型适应于假设的局内点,即所有的未知参数都能从假设的局内点计算得出。
  • 用1中得到的模型去测试所有的其它数据,如果某个点适用于估计的模型,认为它也是局内点。
  • 如果有足够多的点被归类为假设的局内点,那么估计的模型就足够合理。
  • 然后,用所有假设的局内点去重新估计模型(譬如使用最小二乘法),因为它仅仅被初始的假设局内点估计过。
  • 最后,通过估计局内点与模型的错误率来评估模型。
  • 上述过程被重复执行固定的次数,每次产生的模型要么因为局内点太少而被舍弃,要么因为比现有的模型更好而被选用。



整个过程可参考下图: 


3.代码实现

随机一致性采样RANSAC是一种鲁棒的模型拟合算法,能够从有外点的数据中拟合准确的模型。


RANSAC过程中用到的参数

N-- 拟合模型所需要的最少的样本个数

K--算法的迭代次数

t--用于判断数据是否是内点

d--判定模型是否符合使用于数据集,也就是判断是否是好的模型


RANSAC算法过程

1  for K 次迭代

2     从数据中均匀随机采样N个点

3     利用采样的N个点拟合你个模型

4     for 对于除采样点外的每一个样本点

5          利用t检测样本点到模型的距离,如果小于t则认为是一致,否则认为是外点

6     end

7     如果有d或者更多的一致点,则认为拟合的模型是好的

8 end

9 使用拟合误差作为标准,选择最好的拟合模型



迭代次数的计算

假设 r = 内点个数/所有点的个数

 则:

   p0 = pow(r, N) 表示采样的N个点全为内点,也就是是一次有效采样的概率

   p1 = 1 - pow(r, N) 表示采样的N个点中至少有一个外点,即一次无效采样的概率

   p2 = pow(p1, K) 表示K次无效采样的概率

假设p表示K次采样中至少一次采样是有效采样,则有1-p = pow(p1, K), 两边取对数

则有 K = log(1- p )/log(1-p1).


 附一份来自google 的RANSAC的代码框架


[cpp] view plaincopyprint?
  1. #ifndef FVISION_RANSAC_H_  
  2. #define FVISION_RANSAC_H_  
  3.   
  4. #include <fvision/utils/random_utils.h>  
  5. #include <fvision/utils/misc.h>  
  6.   
  7. #include <vector>  
  8. #include <iostream>  
  9. #include <cassert>  
  10.   
  11. namespace fvision {  
  12.   
  13. class RANSAC_SamplesNumber {  
  14. public:  
  15.         RANSAC_SamplesNumber(int modelSampleSize) {  
  16.                 this->s = modelSampleSize;  
  17.                 this->p = 0.99;  
  18.         }  
  19.         ~RANSAC_SamplesNumber(void) {}  
  20.   
  21. public:  
  22.         long calcN(int inliersNumber, int samplesNumber) {  
  23.                 double e = 1 - (double)inliersNumber / samplesNumber;  
  24.                 //cout<<"e: "<<e<<endl;  
  25.                 if (e > 0.9) e = 0.9;  
  26.                 //cout<<"pow: "<<pow((1 - e), s)<<endl;  
  27.                 //cout<<log(1 - pow((1 - e), s))<<endl;  
  28.                 long N = (long)(log(1 - p) / log(1 - pow((1 - e), s)));  
  29.                 if (N < 0) return (long)1000000000;  
  30.                 else return N;  
  31.         }  
  32.   
  33. private:                  
  34.         int s;      //samples size for fitting a model  
  35.         double p;   //probability that at least one of the random samples if free from outliers  
  36.                     //usually 0.99  
  37. };  
  38.   
  39. //fit a model to a set of samples  
  40. template <typename M, typename S>  
  41. class GenericModelCalculator {  
  42. public:  
  43.         typedef std::vector<S> Samples;  
  44.         virtual M compute(const Samples& samples) = 0;  
  45.   
  46.         virtual ~GenericModelCalculator<M, S>() {}  
  47.   
  48.         //the model calculator may only use a subset of the samples for computing  
  49.         //default return empty for both  
  50.         virtual const std::vector<int>& getInlierIndices() const { return defaultInlierIndices; };  
  51.         virtual const std::vector<int>& getOutlierIndices() const { return defaultOutlierIndices; };  
  52.   
  53.         // if the subclass has a threshold parameter, it need to override the following three functions  
  54.         // this is used for algorithms which have a normalization step on input samples  
  55.         virtual bool hasThreshold() const { return false; }  
  56.         virtual void setThreshold(double threshold) {}  
  57.         virtual double getThreshold() const { return 0; }  
  58.   
  59. protected:  
  60.         std::vector<int> defaultInlierIndices;  
  61.         std::vector<int> defaultOutlierIndices;  
  62. };  
  63.   
  64. //evaluate a model to samples  
  65. //using a threshold to distinguish inliers and outliers  
  66. template <typename M, typename S>  
  67. class GenericErrorCaclculator {  
  68. public:  
  69.         virtual ~GenericErrorCaclculator<M, S>() {}  
  70.   
  71.         typedef std::vector<S> Samples;  
  72.   
  73.         virtual double compute(const M& model, const S& sample) const = 0;  
  74.   
  75.         double computeAverage(const M& model, const Samples& samples) const {  
  76.                 int n = (int)samples.size();  
  77.                 if (n == 0) return 0;  
  78.                 double sum = 0;  
  79.                 for (int i = 0; i < n; i++) {  
  80.                         sum += compute(model, samples[i]);  
  81.                 }  
  82.                 return sum / n;  
  83.         }  
  84.   
  85.         double computeInlierAverage(const M& model, const Samples& samples) const {  
  86.                 int n = (int)samples.size();  
  87.                 if (n == 0) return 0;  
  88.                 double sum = 0;  
  89.                 double error = 0;  
  90.                 int inlierNum = 0;  
  91.                 for (int i = 0; i < n; i++) {  
  92.                         error = compute(model, samples[i]);  
  93.                         if (error <= threshold) {  
  94.                                 sum += error;  
  95.                                 inlierNum++;  
  96.                         }  
  97.                 }  
  98.                 if (inlierNum == 0) return 1000000;  
  99.                 return sum / inlierNum;  
  100.         }  
  101.   
  102. public:  
  103.   
  104.         /** set a threshold for classify inliers and outliers 
  105.          */  
  106.         void setThreshold(double v) { threshold = v; }  
  107.   
  108.         double getThreshold() const { return threshold; }  
  109.   
  110.         /** classify all samples to inliers and outliers 
  111.          *  
  112.          */  
  113.         void classify(const M& model, const Samples& samples, Samples& inliers, Samples& outliers) const {  
  114.                 inliers.clear();  
  115.                 outliers.clear();  
  116.                 Samples::const_iterator iter = samples.begin();  
  117.                 for (; iter != samples.end(); ++iter) {  
  118.                         if (isInlier(model, *iter)) inliers.push_back(*iter);  
  119.                         else outliers.push_back(*iter);  
  120.                 }  
  121.         }  
  122.   
  123.         /** classify all samples to inliers and outliers, output indices 
  124.          *  
  125.          */  
  126.         void classify(const M& model, const Samples& samples, std::vector<int>& inlierIndices, std::vector<int>& outlierIndices) const {  
  127.                 inlierIndices.clear();  
  128.                 outlierIndices.clear();  
  129.                 Samples::const_iterator iter = samples.begin();  
  130.                 int i = 0;  
  131.                 for (; iter != samples.end(); ++iter, ++i) {  
  132.                         if (isInlier(model, *iter)) inlierIndices.push_back(i);  
  133.                         else outlierIndices.push_back(i);  
  134.                 }  
  135.         }  
  136.   
  137.         /** classify all samples to inliers and outliers 
  138.          *  
  139.          */  
  140.         void classify(const M& model, const Samples& samples,   
  141.                 std::vector<int>& inlierIndices, std::vector<int>& outlierIndices,   
  142.                 Samples& inliers, Samples& outliers) const {  
  143.   
  144.                 inliers.clear();  
  145.                 outliers.clear();  
  146.                 inlierIndices.clear();  
  147.                 outlierIndices.clear();  
  148.                 Samples::const_iterator iter = samples.begin();  
  149.                 int i = 0;  
  150.                 for (; iter != samples.end(); ++iter, ++i) {  
  151.                         if (isInlier(model, *iter)) {  
  152.                                 inliers.push_back(*iter);  
  153.                                 inlierIndices.push_back(i);  
  154.                         }  
  155.                         else {  
  156.                                 outliers.push_back(*iter);  
  157.                                 outlierIndices.push_back(i);  
  158.                         }  
  159.                 }  
  160.         }  
  161.   
  162.         int calcInliersNumber(const M& model, const Samples& samples) const {  
  163.                 int n = 0;  
  164.                 for (int i = 0; i < (int)samples.size(); i++) {  
  165.                         if (isInlier(model, samples[i])) ++n;  
  166.                 }  
  167.                 return n;  
  168.         }  
  169.   
  170.         bool isInlier(const M& model, const S& sample) const {  
  171.                 return (compute(model, sample) <= threshold);  
  172.         }  
  173.   
  174. private:  
  175.         double threshold;  
  176. };  
  177.   
  178. /** generic RANSAC framework 
  179.  * make use of a model calculator and an error calculator 
  180.  * M is the model type, need to support copy assignment operator and default constructor 
  181.  * S is the sample type. 
  182.  * 
  183.  * Interface: 
  184.  *  M compute(samples); input a set of samples, output a model.  
  185.  *  after compute, inliers and outliers can be retrieved 
  186.  *  
  187.  */  
  188. template <typename M, typename S>  
  189. class Ransac : public GenericModelCalculator<M, S> {  
  190. public:  
  191.         typedef std::vector<S> Samples;  
  192.   
  193.         /** Constructor 
  194.          *  
  195.          * @param pmc a GenericModelCalculator object 
  196.          * @param modelSampleSize how much samples are used to fit a model 
  197.          * @param pec a GenericErrorCaclculator object 
  198.          */  
  199.         Ransac(GenericModelCalculator<M, S>* pmc, int modelSampleSize, GenericErrorCaclculator<M, S>* pec) {  
  200.                 this->pmc = pmc;  
  201.                 this->modelSampleSize = modelSampleSize;  
  202.                 this->pec = pec;  
  203.                 this->maxSampleCount = 500;  
  204.                 this->minInliersNum = 1000000;  
  205.   
  206.                 this->verbose = false;  
  207.         }  
  208.   
  209.         const GenericErrorCaclculator<M, S>* getErrorCalculator() const { return pec; }  
  210.   
  211.         virtual ~Ransac() {  
  212.                 delete pmc;  
  213.                 delete pec;  
  214.         }  
  215.   
  216.         void setMaxSampleCount(int n) {  
  217.                 this->maxSampleCount = n;  
  218.         }  
  219.   
  220.         void setMinInliersNum(int n) {  
  221.                 this->minInliersNum = n;  
  222.         }  
  223.   
  224.         virtual bool hasThreshold() const { return true; }  
  225.   
  226.         virtual void setThreshold(double threshold) {  
  227.                 pec->setThreshold(threshold);  
  228.         }  
  229.   
  230.         virtual double getThreshold() const {  
  231.                 return pec->getThreshold();  
  232.         }  
  233.   
  234. public:  
  235.         /** Given samples, compute a model that has most inliers. Assume the samples size is larger or equal than model sample size 
  236.          * inliers, outliers, inlierIndices and outlierIndices are stored 
  237.          *  
  238.          */  
  239.         M compute(const Samples& samples) {  
  240.                 clear();  
  241.   
  242.                 int pointsNumber = (int)samples.size();  
  243.   
  244.                 assert(pointsNumber >= modelSampleSize);  
  245.   
  246.                 long N = 100000;  
  247.                 int sampleCount = 0;  
  248.                 RANSAC_SamplesNumber ransac(modelSampleSize);  
  249.   
  250.                 M bestModel;  
  251.                 int maxInliersNumber = 0;  
  252.   
  253.                 bool stop = false;  
  254.                 while (sampleCount < N && sampleCount < maxSampleCount && !stop) {  
  255.   
  256.                         Samples nsamples;  
  257.                         randomlySampleN(samples, nsamples, modelSampleSize);  
  258.   
  259.                         M sampleModel = pmc->compute(nsamples);  
  260.                         if (maxInliersNumber == 0) bestModel = sampleModel;  //init bestModel  
  261.   
  262.                         int inliersNumber = pec->calcInliersNumber(sampleModel, samples);  
  263.                         if (verbose) std::cout<<"inliers number: "<<inliersNumber<<std::endl;  
  264.   
  265.                         if (inliersNumber > maxInliersNumber) {  
  266.                                 bestModel = sampleModel;  
  267.                                 maxInliersNumber = inliersNumber;  
  268.                                 N = ransac.calcN(inliersNumber, pointsNumber);  
  269.                                 if (maxInliersNumber > minInliersNum) stop = true;  
  270.                         }  
  271.   
  272.                         if (verbose) std::cout<<"N: "<<N<<std::endl;  
  273.   
  274.                         sampleCount ++;  
  275.                 }  
  276.   
  277.                 if (verbose) std::cout<<"sampleCount: "<<sampleCount<<std::endl;  
  278.   
  279.                 finalModel = computeUntilConverge(bestModel, maxInliersNumber, samples);  
  280.                   
  281.                 pec->classify(finalModel, samples, inlierIndices, outlierIndices, inliers, outliers);  
  282.   
  283.                 inliersRate = (double)inliers.size() / samples.size();  
  284.   
  285.                 return finalModel;  
  286.         }  
  287.   
  288.         const Samples& getInliers() const { return inliers; }  
  289.         const Samples& getOutliers() const { return outliers; }  
  290.   
  291.         const std::vector<int>& getInlierIndices() const { return inlierIndices; }  
  292.         const std::vector<int>& getOutlierIndices() const { return outlierIndices; }  
  293.   
  294.         double getInliersAverageError() const {  
  295.                 return pec->computeAverage(finalModel, inliers);  
  296.         }  
  297.   
  298.         double getInliersRate() const {  
  299.                 return inliersRate;  
  300.         }  
  301.   
  302.         void setVerbose(bool v) {  
  303.                 verbose = v;  
  304.         }  
  305.   
  306. private:  
  307.         void randomlySampleN(const Samples& samples, Samples& nsamples, int sampleSize) {  
  308.                 std::vector<int> is = ranis((int)samples.size(), sampleSize);  
  309.                 for (int i = 0; i < sampleSize; i++) {  
  310.                         nsamples.push_back(samples[is[i]]);  
  311.                 }  
  312.         }  
  313.   
  314.         /** from initial model, iterate to find the best model. 
  315.          * 
  316.          */  
  317.         M computeUntilConverge(M initModel, int initInliersNum, const Samples& samples) {  
  318.                 if (verbose) {  
  319.                         std::cout<<"iterate until converge...."<<std::endl;  
  320.                         std::cout<<"init inliers number: "<<initInliersNum<<std::endl;  
  321.                 }  
  322.   
  323.                 M bestModel = initModel;  
  324.                 M newModel = initModel;  
  325.   
  326.                 int lastInliersNum = initInliersNum;  
  327.   
  328.                 Samples newInliers, newOutliers;  
  329.                 pec->classify(initModel, samples, newInliers, newOutliers);  
  330.                 double lastInlierAverageError = pec->computeAverage(initModel, newInliers);  
  331.   
  332.                 if (verbose) std::cout<<"init inlier average error: "<<lastInlierAverageError<<std::endl;  
  333.   
  334.                 while (true && (int)newInliers.size() >= modelSampleSize) {  
  335.   
  336.                         //update new model with new inliers, the new model does not necessarily have more inliers  
  337.                         newModel = pmc->compute(newInliers);  
  338.   
  339.                         pec->classify(newModel, samples, newInliers, newOutliers);  
  340.   
  341.                         int newInliersNum = (int)newInliers.size();  
  342.                         double newInlierAverageError = pec->computeAverage(newModel, newInliers);  
  343.   
  344.                         if (verbose) {  
  345.                                 std::cout<<"new inliers number: "<<newInliersNum<<std::endl;  
  346.                                 std::cout<<"new inlier average error: "<<newInlierAverageError<<std::endl;  
  347.                         }  
  348.                         if (newInliersNum < lastInliersNum) break;  
  349.                         if (newInliersNum == lastInliersNum && newInlierAverageError >= lastInlierAverageError) break;  
  350.   
  351.                         //update best model with the model has more inliers  
  352.                         bestModel = newModel;  
  353.   
  354.                         lastInliersNum = newInliersNum;  
  355.                         lastInlierAverageError = newInlierAverageError;  
  356.                 }  
  357.   
  358.                 return bestModel;  
  359.         }  
  360.   
  361.         void clear() {  
  362.                 inliers.clear();  
  363.                 outliers.clear();  
  364.                 inlierIndices.clear();  
  365.                 outlierIndices.clear();  
  366.         }  
  367.   
  368. private:  
  369.         GenericModelCalculator<M, S>* pmc;  
  370.         GenericErrorCaclculator<M, S>* pec;  
  371.         int modelSampleSize;  
  372.   
  373.         int maxSampleCount;  
  374.         int minInliersNum;  
  375.   
  376.         M finalModel;  
  377.   
  378.         Samples inliers;  
  379.         Samples outliers;  
  380.   
  381.         std::vector<int> inlierIndices;  
  382.         std::vector<int> outlierIndices;  
  383.   
  384.         double inliersRate;  
  385.   
  386. private:  
  387.         bool verbose;  
  388.   
  389. };  
  390.   
  391. }  
  392. #endif // FVISION_RANSAC_H_  

实例2

  1. #include <math.h>  
  2. #include "LineParamEstimator.h"  
  3.   
  4. LineParamEstimator::LineParamEstimator(double delta) : m_deltaSquared(delta*delta) {}  
  5. /*****************************************************************************/  
  6. /* 
  7.  * Compute the line parameters  [n_x,n_y,a_x,a_y] 
  8.  * 通过输入的两点来确定所在直线,采用法线向量的方式来表示,以兼容平行或垂直的情况 
  9.  * 其中n_x,n_y为归一化后,与原点构成的法线向量,a_x,a_y为直线上任意一点 
  10.  */  
  11. void LineParamEstimator::estimate(std::vector<Point2D *> &data,   
  12.                                                                     std::vector<double> ¶meters)  
  13. {  
  14.     parameters.clear();  
  15.     if(data.size()<2)  
  16.         return;  
  17.     double nx = data[1]->y - data[0]->y;  
  18.     double ny = data[0]->x - data[1]->x;// 原始直线的斜率为K,则法线的斜率为-1/k  
  19.     double norm = sqrt(nx*nx + ny*ny);  
  20.       
  21.     parameters.push_back(nx/norm);  
  22.     parameters.push_back(ny/norm);  
  23.     parameters.push_back(data[0]->x);  
  24.     parameters.push_back(data[0]->y);          
  25. }  
  26. /*****************************************************************************/  
  27. /* 
  28.  * Compute the line parameters  [n_x,n_y,a_x,a_y] 
  29.  * 使用最小二乘法,从输入点中拟合出确定直线模型的所需参量 
  30.  */  
  31. void LineParamEstimator::leastSquaresEstimate(std::vector<Point2D *> &data,   
  32.                                                                                             std::vector<double> ¶meters)  
  33. {  
  34.     double meanX, meanY, nx, ny, norm;  
  35.     double covMat11, covMat12, covMat21, covMat22; // The entries of the symmetric covarinace matrix  
  36.     int i, dataSize = data.size();  
  37.   
  38.     parameters.clear();  
  39.     if(data.size()<2)  
  40.         return;  
  41.   
  42.     meanX = meanY = 0.0;  
  43.     covMat11 = covMat12 = covMat21 = covMat22 = 0;  
  44.     for(i=0; i<dataSize; i++) {  
  45.         meanX +=data[i]->x;  
  46.         meanY +=data[i]->y;  
  47.   
  48.         covMat11    +=data[i]->x * data[i]->x;  
  49.         covMat12    +=data[i]->x * data[i]->y;  
  50.         covMat22    +=data[i]->y * data[i]->y;  
  51.     }  
  52.   
  53.     meanX/=dataSize;  
  54.     meanY/=dataSize;  
  55.   
  56.     covMat11 -= dataSize*meanX*meanX;  
  57.         covMat12 -= dataSize*meanX*meanY;  
  58.     covMat22 -= dataSize*meanY*meanY;  
  59.     covMat21 = covMat12;  
  60.   
  61.     if(covMat11<1e-12) {  
  62.         nx = 1.0;  
  63.             ny = 0.0;  
  64.     }  
  65.     else {      //lamda1 is the largest eigen-value of the covariance matrix   
  66.                //and is used to compute the eigne-vector corresponding to the smallest  
  67.                //eigenvalue, which isn't computed explicitly.  
  68.         double lamda1 = (covMat11 + covMat22 + sqrt((covMat11-covMat22)*(covMat11-covMat22) + 4*covMat12*covMat12)) / 2.0;  
  69.         nx = -covMat12;  
  70.         ny = lamda1 - covMat22;  
  71.         norm = sqrt(nx*nx + ny*ny);  
  72.         nx/=norm;  
  73.         ny/=norm;  
  74.     }  
  75.     parameters.push_back(nx);  
  76.     parameters.push_back(ny);  
  77.     parameters.push_back(meanX);  
  78.     parameters.push_back(meanY);  
  79. }  
  80. /*****************************************************************************/  
  81. /* 
  82.  * Given the line parameters  [n_x,n_y,a_x,a_y] check if 
  83.  * [n_x, n_y] dot [data.x-a_x, data.y-a_y] < m_delta 
  84.  * 通过与已知法线的点乘结果,确定待测点与已知直线的匹配程度;结果越小则越符合,为 
  85.  * 零则表明点在直线上 
  86.  */  
  87. bool LineParamEstimator::agree(std::vector<double> ¶meters, Point2D &data)  
  88. {  
  89.     double signedDistance = parameters[0]*(data.x-parameters[2]) + parameters[1]*(data.y-parameters[3]);   
  90.     return ((signedDistance*signedDistance) < m_deltaSquared);  
  91. }  


RANSAC寻找匹配的代码如下:

[cpp] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. /*****************************************************************************/  
  2. template<class T, class S>  
  3. double Ransac<T,S>::compute(std::vector<S> ¶meters,   
  4.                                                       ParameterEsitmator<T,S> *paramEstimator ,   
  5.                                                     std::vector<T> &data,   
  6.                                                     int numForEstimate)  
  7. {  
  8.     std::vector<T *> leastSquaresEstimateData;  
  9.     int numDataObjects = data.size();  
  10.     int numVotesForBest = -1;  
  11.     int *arr = new int[numForEstimate];// numForEstimate表示拟合模型所需要的最少点数,对本例的直线来说,该值为2  
  12.     short *curVotes = new short[numDataObjects];  //one if data[i] agrees with the current model, otherwise zero  
  13.     short *bestVotes = new short[numDataObjects];  //one if data[i] agrees with the best model, otherwise zero  
  14.       
  15.   
  16.               //there are less data objects than the minimum required for an exact fit  
  17.     if(numDataObjects < numForEstimate)   
  18.         return 0;  
  19.         // 计算所有可能的直线,寻找其中误差最小的解。对于100点的直线拟合来说,大约需要100*99*0.5=4950次运算,复杂度无疑是庞大的。一般采用随机选取子集的方式。  
  20.     computeAllChoices(paramEstimator,data,numForEstimate,  
  21.                                         bestVotes, curVotes, numVotesForBest, 0, data.size(), numForEstimate, 0, arr);  
  22.   
  23.        //compute the least squares estimate using the largest sub set  
  24.     for(int j=0; j<numDataObjects; j++) {  
  25.         if(bestVotes[j])  
  26.             leastSquaresEstimateData.push_back(&(data[j]));  
  27.     }  
  28.         // 对局内点再次用最小二乘法拟合出模型  
  29.     paramEstimator->leastSquaresEstimate(leastSquaresEstimateData,parameters);  
  30.   
  31.     delete [] arr;  
  32.     delete [] bestVotes;  
  33.     delete [] curVotes;   
  34.   
  35.     return (double)leastSquaresEstimateData.size()/(double)numDataObjects;  
  36. }  
1 0
原创粉丝点击