C++最小二乘法拟合-(线性拟合和多项式拟合)

来源:互联网 发布:网络视频策划制作 编辑:程序博客网 时间:2024/04/28 11:07

在进行曲线拟合时用的最多的是最小二乘法,其中以一元函数(线性)和多元函数(多项式)居多,下面这个类专门用于进行多项式拟合,可以根据用户输入的阶次进行多项式拟合,算法来自于网上,和GSL的拟合算法对比过,没有问题。此类在拟合完后还能计算拟合之后的误差:SSE(剩余平方和),SSR(回归平方和),RMSE(均方根误差),R-square(确定系数)。


1.fit类的实现

先看看fit类的代码:(只有一个头文件方便使用)

这是用网上的代码实现的,下面有用GSL实现的版本

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #ifndef CZY_MATH_FIT  
  2. #define CZY_MATH_FIT  
  3. #include <vector>  
  4. /* 
  5. 尘中远,于2014.03.20 
  6. 主页:http://blog.csdn.net/czyt1988/article/details/21743595 
  7. 参考:http://blog.csdn.net/maozefa/article/details/1725535 
  8. */  
  9. namespace czy{  
  10.     ///  
  11.     /// \brief 曲线拟合类  
  12.     ///  
  13.     class Fit{  
  14.         std::vector<double> factor; ///<拟合后的方程系数  
  15.         double ssr;                 ///<回归平方和  
  16.         double sse;                 ///<(剩余平方和)  
  17.         double rmse;                ///<RMSE均方根误差  
  18.         std::vector<double> fitedYs;///<存放拟合后的y值,在拟合时可设置为不保存节省内存  
  19.     public:  
  20.         Fit():ssr(0),sse(0),rmse(0){factor.resize(2,0);}  
  21.         ~Fit(){}  
  22.         ///  
  23.         /// \brief 直线拟合-一元回归,拟合的结果可以使用getFactor获取,或者使用getSlope获取斜率,getIntercept获取截距  
  24.         /// \param x 观察值的x  
  25.         /// \param y 观察值的y  
  26.         /// \param isSaveFitYs 拟合后的数据是否保存,默认否  
  27.         ///  
  28.         template<typename T>  
  29.         bool linearFit(const std::vector<typename T>& x, const std::vector<typename T>& y,bool isSaveFitYs=false)  
  30.         {  
  31.             return linearFit(&x[0],&y[0],getSeriesLength(x,y),isSaveFitYs);  
  32.         }  
  33.         template<typename T>  
  34.         bool linearFit(const T* x, const T* y,size_t length,bool isSaveFitYs=false)  
  35.         {  
  36.             factor.resize(2,0);  
  37.             typename T t1=0, t2=0, t3=0, t4=0;  
  38.             for(int i=0; i<length; ++i)  
  39.             {  
  40.                 t1 += x[i]*x[i];  
  41.                 t2 += x[i];  
  42.                 t3 += x[i]*y[i];  
  43.                 t4 += y[i];  
  44.             }  
  45.             factor[1] = (t3*length - t2*t4) / (t1*length - t2*t2);  
  46.             factor[0] = (t1*t4 - t2*t3) / (t1*length - t2*t2);  
  47.             //////////////////////////////////////////////////////////////////////////  
  48.             //计算误差  
  49.             calcError(x,y,length,this->ssr,this->sse,this->rmse,isSaveFitYs);  
  50.             return true;  
  51.         }  
  52.         ///  
  53.         /// \brief 多项式拟合,拟合y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n  
  54.         /// \param x 观察值的x  
  55.         /// \param y 观察值的y  
  56.         /// \param poly_n 期望拟合的阶数,若poly_n=2,则y=a0+a1*x+a2*x^2  
  57.         /// \param isSaveFitYs 拟合后的数据是否保存,默认是  
  58.         ///   
  59.         template<typename T>  
  60.         void polyfit(const std::vector<typename T>& x  
  61.             ,const std::vector<typename T>& y  
  62.             ,int poly_n  
  63.             ,bool isSaveFitYs=true)  
  64.         {  
  65.             polyfit(&x[0],&y[0],getSeriesLength(x,y),poly_n,isSaveFitYs);  
  66.         }  
  67.         template<typename T>  
  68.         void polyfit(const T* x,const T* y,size_t length,int poly_n,bool isSaveFitYs=true)  
  69.         {  
  70.             factor.resize(poly_n+1,0);  
  71.             int i,j;  
  72.             //double *tempx,*tempy,*sumxx,*sumxy,*ata;  
  73.             std::vector<double> tempx(length,1.0);  
  74.   
  75.             std::vector<double> tempy(y,y+length);  
  76.   
  77.             std::vector<double> sumxx(poly_n*2+1);  
  78.             std::vector<double> ata((poly_n+1)*(poly_n+1));  
  79.             std::vector<double> sumxy(poly_n+1);  
  80.             for (i=0;i<2*poly_n+1;i++){  
  81.                 for (sumxx[i]=0,j=0;j<length;j++)  
  82.                 {  
  83.                     sumxx[i]+=tempx[j];  
  84.                     tempx[j]*=x[j];  
  85.                 }  
  86.             }  
  87.             for (i=0;i<poly_n+1;i++){  
  88.                 for (sumxy[i]=0,j=0;j<length;j++)  
  89.                 {  
  90.                     sumxy[i]+=tempy[j];  
  91.                     tempy[j]*=x[j];  
  92.                 }  
  93.             }  
  94.             for (i=0;i<poly_n+1;i++)  
  95.                 for (j=0;j<poly_n+1;j++)  
  96.                     ata[i*(poly_n+1)+j]=sumxx[i+j];  
  97.             gauss_solve(poly_n+1,ata,factor,sumxy);  
  98.             //计算拟合后的数据并计算误差  
  99.             fitedYs.reserve(length);  
  100.             calcError(&x[0],&y[0],length,this->ssr,this->sse,this->rmse,isSaveFitYs);  
  101.   
  102.         }  
  103.         ///   
  104.         /// \brief 获取系数  
  105.         /// \param 存放系数的数组  
  106.         ///  
  107.         void getFactor(std::vector<double>& factor){factor = this->factor;}  
  108.         ///   
  109.         /// \brief 获取拟合方程对应的y值,前提是拟合时设置isSaveFitYs为true  
  110.         ///  
  111.         void getFitedYs(std::vector<double>& fitedYs){fitedYs = this->fitedYs;}  
  112.   
  113.         ///   
  114.         /// \brief 根据x获取拟合方程的y值  
  115.         /// \return 返回x对应的y值  
  116.         ///  
  117.         template<typename T>  
  118.         double getY(const T x) const  
  119.         {  
  120.             double ans(0);  
  121.             for (size_t i=0;i<factor.size();++i)  
  122.             {  
  123.                 ans += factor[i]*pow((double)x,(int)i);  
  124.             }  
  125.             return ans;  
  126.         }  
  127.         ///   
  128.         /// \brief 获取斜率  
  129.         /// \return 斜率值  
  130.         ///  
  131.         double getSlope(){return factor[1];}  
  132.         ///   
  133.         /// \brief 获取截距  
  134.         /// \return 截距值  
  135.         ///  
  136.         double getIntercept(){return factor[0];}  
  137.         ///   
  138.         /// \brief 剩余平方和  
  139.         /// \return 剩余平方和  
  140.         ///  
  141.         double getSSE(){return sse;}  
  142.         ///   
  143.         /// \brief 回归平方和  
  144.         /// \return 回归平方和  
  145.         ///  
  146.         double getSSR(){return ssr;}  
  147.         ///   
  148.         /// \brief 均方根误差  
  149.         /// \return 均方根误差  
  150.         ///  
  151.         double getRMSE(){return rmse;}  
  152.         ///   
  153.         /// \brief 确定系数,系数是0~1之间的数,是数理上判定拟合优度的一个量  
  154.         /// \return 确定系数  
  155.         ///  
  156.         double getR_square(){return 1-(sse/(ssr+sse));}  
  157.         ///   
  158.         /// \brief 获取两个vector的安全size  
  159.         /// \return 最小的一个长度  
  160.         ///  
  161.         template<typename T>  
  162.         size_t getSeriesLength(const std::vector<typename T>& x  
  163.             ,const std::vector<typename T>& y)  
  164.         {  
  165.             return (x.size() > y.size() ? y.size() : x.size());  
  166.         }  
  167.         ///   
  168.         /// \brief 计算均值  
  169.         /// \return 均值  
  170.         ///  
  171.         template <typename T>  
  172.         static T Mean(const std::vector<T>& v)  
  173.         {  
  174.             return Mean(&v[0],v.size());  
  175.         }  
  176.         template <typename T>  
  177.         static T Mean(const T* v,size_t length)  
  178.         {  
  179.             T total(0);  
  180.             for (size_t i=0;i<length;++i)  
  181.             {  
  182.                 total += v[i];  
  183.             }  
  184.             return (total / length);  
  185.         }  
  186.         ///   
  187.         /// \brief 获取拟合方程系数的个数  
  188.         /// \return 拟合方程系数的个数  
  189.         ///  
  190.         size_t getFactorSize(){return factor.size();}  
  191.         ///   
  192.         /// \brief 根据阶次获取拟合方程的系数,  
  193.         /// 如getFactor(2),就是获取y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n中a2的值  
  194.         /// \return 拟合方程的系数  
  195.         ///  
  196.         double getFactor(size_t i){return factor.at(i);}  
  197.     private:  
  198.         template<typename T>  
  199.         void calcError(const T* x  
  200.             ,const T* y  
  201.             ,size_t length  
  202.             ,double& r_ssr  
  203.             ,double& r_sse  
  204.             ,double& r_rmse  
  205.             ,bool isSaveFitYs=true  
  206.             )  
  207.         {  
  208.             T mean_y = Mean<T>(y,length);  
  209.             T yi(0);  
  210.             fitedYs.reserve(length);  
  211.             for (int i=0; i<length; ++i)  
  212.             {  
  213.                 yi = getY(x[i]);  
  214.                 r_ssr += ((yi-mean_y)*(yi-mean_y));//计算回归平方和  
  215.                 r_sse += ((yi-y[i])*(yi-y[i]));//残差平方和  
  216.                 if (isSaveFitYs)  
  217.                 {  
  218.                     fitedYs.push_back(double(yi));  
  219.                 }  
  220.             }  
  221.             r_rmse = sqrt(r_sse/(double(length)));  
  222.         }  
  223.         template<typename T>  
  224.         void gauss_solve(int n  
  225.             ,std::vector<typename T>& A  
  226.             ,std::vector<typename T>& x  
  227.             ,std::vector<typename T>& b)  
  228.         {  
  229.             gauss_solve(n,&A[0],&x[0],&b[0]);     
  230.         }  
  231.         template<typename T>  
  232.         void gauss_solve(int n  
  233.             ,T* A  
  234.             ,T* x  
  235.             ,T* b)  
  236.         {  
  237.             int i,j,k,r;  
  238.             double max;  
  239.             for (k=0;k<n-1;k++)  
  240.             {  
  241.                 max=fabs(A[k*n+k]); /*find maxmum*/  
  242.                 r=k;  
  243.                 for (i=k+1;i<n-1;i++){  
  244.                     if (max<fabs(A[i*n+i]))  
  245.                     {  
  246.                         max=fabs(A[i*n+i]);  
  247.                         r=i;  
  248.                     }  
  249.                 }  
  250.                 if (r!=k){  
  251.                     for (i=0;i<n;i++)         /*change array:A[k]&A[r] */  
  252.                     {  
  253.                         max=A[k*n+i];  
  254.                         A[k*n+i]=A[r*n+i];  
  255.                         A[r*n+i]=max;  
  256.                     }  
  257.                 }  
  258.                 max=b[k];                    /*change array:b[k]&b[r]     */  
  259.                 b[k]=b[r];  
  260.                 b[r]=max;  
  261.                 for (i=k+1;i<n;i++)  
  262.                 {  
  263.                     for (j=k+1;j<n;j++)  
  264.                         A[i*n+j]-=A[i*n+k]*A[k*n+j]/A[k*n+k];  
  265.                     b[i]-=A[i*n+k]*b[k]/A[k*n+k];  
  266.                 }  
  267.             }   
  268.   
  269.             for (i=n-1;i>=0;x[i]/=A[i*n+i],i--)  
  270.                 for (j=i+1,x[i]=b[i];j<n;j++)  
  271.                     x[i]-=A[i*n+j]*x[j];  
  272.         }  
  273.     };  
  274. }  
  275.   
  276.   
  277. #endif  

GSL实现版本,此版本依赖于GSL需要先配置GSL,GSL配置方法网上很多,我的blog也有一篇介绍win + Qt环境下的配置,其它大同小异:http://blog.csdn.NET/czyt1988/article/details/39178975


[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #ifndef CZYMATH_FIT_H  
  2. #define CZYMATH_FIT_H  
  3.   
  4. #include <czyMath.h>  
  5. namespace gsl{  
  6.     #include <gsl/gsl_fit.h>  
  7.     #include <gsl/gsl_cdf.h>    /* 提供了 gammaq 函数 */  
  8.     #include <gsl/gsl_vector.h> /* 提供了向量结构*/  
  9.     #include <gsl/gsl_matrix.h>  
  10.     #include <gsl/gsl_multifit.h>  
  11. }  
  12.   
  13. namespace czy {  
  14. ///  
  15. /// \brief The Math class 用于处理简单数学计算  
  16. ///  
  17.     namespace Math{  
  18.         using namespace gsl;  
  19.     ///  
  20.     /// \brief 拟合类,封装了gsl的拟合算法  
  21.     ///  
  22.     /// 实现线性拟合和多项式拟合  
  23.     ///  
  24.         class fit{  
  25.         public:  
  26.             fit(){}  
  27.             ~fit(){}  
  28.         private:  
  29.             std::map<double,double> m_factor;//记录各个点的系数,key中0是0次方,1是1次方,value是对应的系数  
  30.             std::map<double,double> m_err;  
  31.             double m_cov;//相关度  
  32.             double m_ssr;//回归平方和  
  33.             double m_sse;//(剩余平方和)  
  34.             double m_rmse;//RMSE均方根误差  
  35.             double m_wssr;  
  36.             double m_goodness;//基于wssr的拟合优度  
  37.             void clearAll(){  
  38.                 m_factor.clear();m_err.clear();  
  39.             }  
  40.         public:  
  41.             //计算拟合的显著性  
  42.             static  void getDeterminateOfCoefficient(  
  43.                 const double* y,const double* yi,size_t length  
  44.                 ,double& out_ssr,double& out_sse,double& out_sst,double& out_rmse,double& out_RSquare)  
  45.             {  
  46.                 double y_mean = mean(y,y+length);  
  47.                 out_ssr = 0.0;  
  48.                 for (size_t i =0;i<length;++i)  
  49.                 {  
  50.                     out_ssr += ((yi[i]-y_mean)*(yi[i]-y_mean));  
  51.                     out_sse += ((y[i] - yi[i])*(y[i] - yi[i]));  
  52.                 }  
  53.                 out_sst = out_ssr + out_sse;  
  54.                 out_rmse = sqrt(out_sse/(double(length)));  
  55.                 out_RSquare = out_ssr/out_sst;  
  56.             }  
  57.             ///  
  58.             /// \brief 获取拟合的系数  
  59.             /// \param n 0是0次方,1是1次方,value是对应的系数  
  60.             /// \return 次幂对应的系数  
  61.             ///  
  62.             double getFactor(double n)  
  63.             {  
  64.                 auto ite = m_factor.find(n);  
  65.                 if (ite == m_factor.end())  
  66.                     return 0.0;  
  67.                 return ite->second;  
  68.             }  
  69.             ///  
  70.             /// \brief 获取系数的个数  
  71.             /// \return  
  72.             ///  
  73.             size_t getFactorSize()  
  74.             {  
  75.                 return m_factor.size();  
  76.             }  
  77.             ///  
  78.             /// \brief linearFit 线性拟合的静态函数  
  79.             /// \param x 数据点的横坐标值数组  
  80.             /// \param xstride 横坐标值数组索引步长 xstride 与 ystride 的值设为 1,表示数据点集 {(xi,yi)|i=0,1,⋯,n−1} 全部参与直线的拟合;  
  81.             /// \param y 数据点的纵坐标值数组  
  82.             /// \param ystride 纵坐标值数组索引步长  
  83.             /// \param n 数据点的数量  
  84.             /// \param out_intercept 计算的截距  
  85.             /// \param out_slope 计算的斜率  
  86.             /// \param out_interceptErr 计算的截距误差  
  87.             /// \param out_slopeErr 计算的斜率误差  
  88.             /// \param out_cov 计算的斜率和截距的相关度  
  89.             /// \param out_wssr 拟合的wssr值  
  90.             /// \return  
  91.             ///  
  92.             static int linearFit(  
  93.                 const double *x  
  94.                 ,const size_t xstride  
  95.                 ,const double *y  
  96.                 ,const size_t ystride  
  97.                 ,size_t n  
  98.                 ,double& out_intercept  
  99.                 ,double& out_slope  
  100.                 ,double& out_interceptErr  
  101.                 ,double& out_slopeErr  
  102.                 ,double& out_cov  
  103.                 ,double& out_wssr  
  104.                 )  
  105.             {  
  106.                 return gsl_fit_linear(x,xstride,y,ystride,n  
  107.                     ,&out_intercept,&out_slope,&out_interceptErr,&out_slopeErr,&out_cov,&out_wssr);  
  108.             }  
  109.             ///  
  110.             /// \brief  线性拟合  
  111.             /// \param x 拟合的x值  
  112.             /// \param y 拟合的y值  
  113.             /// \param n x,y值对应的长度  
  114.             /// \return  
  115.             ///  
  116.             bool linearFit(const double *x,const double *y,size_t n)  
  117.             {  
  118.                 clearAll();  
  119.                 m_factor[0]=0;m_err[0]=0;  
  120.                 m_factor[1]=1;m_err[1]=0;  
  121.                 int r = linearFit(x,1,y,1,n  
  122.                     ,m_factor[0],m_factor[1],m_err[0],m_err[1],m_cov,m_wssr);  
  123.                 if (0 != r)  
  124.                     return false;  
  125.                 m_goodness = gsl_cdf_chisq_Q(m_wssr/2.0,(n-2)/2.0);//计算优度  
  126.                 {  
  127.                     std::vector<double> yi;  
  128.                     getYis(x,n,yi);  
  129.                     double t;  
  130.                     getDeterminateOfCoefficient(y,&yi[0],n,m_ssr,m_sse,t,m_rmse,t);  
  131.                 }  
  132.                 return true;  
  133.             }  
  134.             bool linearFit(const std::vector<double>& x,const std::vector<double>& y)  
  135.             {  
  136.                 size_t n = x.size() > y.size() ? y.size() :x.size();  
  137.                 return linearFit(&x[0],&y[0],n);  
  138.             }  
  139.             ///  
  140.             /// \brief 多项式拟合  
  141.             /// \param poly_n 阶次,如c0+C1x是1,若c0+c1x+c2x^2则poly_n是2  
  142.             static int polyfit(const double *x  
  143.                 ,const double *y  
  144.                 ,size_t xyLength  
  145.                 ,unsigned poly_n  
  146.                 ,std::vector<double>& out_factor  
  147.                 ,double& out_chisq)//拟合曲线与数据点的优值函数最小值 ,χ2 检验  
  148.             {  
  149.                 gsl_matrix *XX = gsl_matrix_alloc(xyLength, poly_n + 1);  
  150.                 gsl_vector *c = gsl_vector_alloc(poly_n + 1);  
  151.                 gsl_matrix *cov = gsl_matrix_alloc(poly_n + 1, poly_n + 1);  
  152.                 gsl_vector *vY = gsl_vector_alloc(xyLength);  
  153.   
  154.                 for(size_t i = 0; i < xyLength; i++)  
  155.                 {  
  156.                     gsl_matrix_set(XX, i, 0, 1.0);  
  157.                     gsl_vector_set (vY, i, y[i]);  
  158.                     for(unsigned j = 1; j <= poly_n; j++)  
  159.                     {  
  160.                         gsl_matrix_set(XX, i, j, pow(x[i], int(j) ));  
  161.                     }  
  162.                 }  
  163.                 gsl_multifit_linear_workspace *workspace = gsl_multifit_linear_alloc(xyLength, poly_n + 1);  
  164.                 int r = gsl_multifit_linear(XX, vY, c, cov, &out_chisq, workspace);  
  165.                 gsl_multifit_linear_free(workspace);  
  166.                 out_factor.resize(c->size,0);  
  167.                 for (size_t i=0;i<c->size;++i)  
  168.                 {  
  169.                     out_factor[i] = gsl_vector_get(c,i);  
  170.                 }  
  171.   
  172.                 gsl_vector_free(vY);  
  173.                 gsl_matrix_free(XX);  
  174.                 gsl_matrix_free(cov);  
  175.                 gsl_vector_free(c);  
  176.   
  177.                 return r;  
  178.             }  
  179.             bool polyfit(const double *x  
  180.                 ,const double *y  
  181.                 ,size_t xyLength  
  182.                 ,unsigned poly_n)  
  183.             {  
  184.                 double chisq;  
  185.                 std::vector<double> factor;  
  186.                 int r = polyfit(x,y,xyLength,poly_n,factor,chisq);  
  187.                 if (0 != r)  
  188.                     return false;  
  189.                 m_goodness = gsl_cdf_chisq_Q(chisq/2.0,(xyLength-2)/2.0);//计算优度  
  190.   
  191.                 clearAll();  
  192.                 for (unsigned i=0;i<poly_n+1;++i)  
  193.                 {  
  194.                     m_factor[i]=factor[i];  
  195.                 }  
  196.                 std::vector<double> yi;  
  197.                 getYis(x,xyLength,yi);  
  198.                 double t;//由于没用到,所以都用t代替  
  199.                 getDeterminateOfCoefficient(y,&yi[0],xyLength,m_ssr,m_sse,t,m_rmse,t);  
  200.   
  201.                 return true;  
  202.             }  
  203.             bool polyfit(const std::vector<double>& x  
  204.                          ,const std::vector<double>& y  
  205.                          ,unsigned plotN)  
  206.             {  
  207.                 size_t n = x.size() > y.size() ? y.size() :x.size();  
  208.                 return polyfit(&x[0],&y[0],n,plotN);  
  209.             }  
  210.   
  211.             double getYi(double x) const  
  212.             {  
  213.                 double ans(0);  
  214.                 for (auto ite = m_factor.begin();ite != m_factor.end();++ite)  
  215.                 {  
  216.                     ans += (ite->second)*pow(x,ite->first);  
  217.                 }  
  218.                 return ans;  
  219.             }  
  220.             void getYis(const double* x,size_t length,std::vector<double>& yis) const  
  221.             {  
  222.                 yis.clear();  
  223.                 yis.resize(length);  
  224.                 for(size_t i=0;i<length;++i)  
  225.                 {  
  226.                     yis[i] = getYi(x[i]);  
  227.                 }  
  228.             }  
  229.             ///  
  230.             /// \brief 获取斜率  
  231.             /// \return 斜率值  
  232.             ///  
  233.             double getSlope() {return m_factor[1];}  
  234.             ///  
  235.             /// \brief 获取截距  
  236.             /// \return 截距值  
  237.             ///  
  238.             double getIntercept() {return m_factor[0];}  
  239.   
  240.             ///  
  241.             /// \brief 回归平方和  
  242.             /// \return 回归平方和  
  243.             ///  
  244.             double getSSR() const {return m_ssr;}  
  245.             double getSSE() const {return m_sse;}  
  246.             double getSST() const {return m_ssr+m_sse;}  
  247.             double getRMSE() const {return m_rmse;}  
  248.             double getRSquare() const {return 1.0-(m_sse/(m_ssr+m_sse));}  
  249.             double getGoodness() const {return m_goodness;}  
  250.         };  
  251.     }  
  252. }  
  253. #endif // CZYMATH_FIT_H  



为了防止重命名,把其放置于czy的命名空间中,此类主要两个函数:

1.求解线性拟合:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. ///  
  2. /// \brief 直线拟合-一元回归,拟合的结果可以使用getFactor获取,或者使用getSlope获取斜率,getIntercept获取截距  
  3. /// \param x 观察值的x  
  4. /// \param y 观察值的y  
  5. /// \param length x,y数组的长度  
  6. /// \param isSaveFitYs 拟合后的数据是否保存,默认否  
  7. ///  
  8. template<typename T>  
  9. bool linearFit(const std::vector<typename T>& x, const std::vector<typename T>& y,bool isSaveFitYs=false);  
  10. template<typename T>  
  11. bool linearFit(const T* x, const T* y,size_t length,bool isSaveFitYs=false);  


2.多项式拟合:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. ///  
  2. /// \brief 多项式拟合,拟合y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n  
  3. /// \param x 观察值的x  
  4. /// \param y 观察值的y  
  5. /// \param length x,y数组的长度  
  6. /// \param poly_n 期望拟合的阶数,若poly_n=2,则y=a0+a1*x+a2*x^2  
  7. /// \param isSaveFitYs 拟合后的数据是否保存,默认是  
  8. ///   
  9. template<typename T>  
  10. void polyfit(const std::vector<typename T>& x,const std::vector<typename T>& y,int poly_n,bool isSaveFitYs=true);  
  11. template<typename T>  
  12. void polyfit(const T* x,const T* y,size_t length,int poly_n,bool isSaveFitYs=true);  


这两个函数都用模板函数形式写,主要是为了能使用于float和double两种数据类型


2.fit类的MFC示范程序

下面看看如何使用这个类,以MFC示范,使用了开源的绘图控件Hight-Speed Charting,使用方法见http://blog.csdn.net/czyt1988/article/details/8740500

新建对话框文件,

对话框资源文件如图所示:


加入下面的这些变量:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. std::vector<double> m_x,m_y,m_yploy;  
  2. const size_t m_size;  
  3. CChartLineSerie *m_pLineSerie1;  
  4. CChartLineSerie *m_pLineSerie2;  

由于m_size是常量,因此需要在构造函数进行初始化,如:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. ClineFitDlg::ClineFitDlg(CWnd* pParent /*=NULL*/)  
  2.     : CDialogEx(ClineFitDlg::IDD, pParent)  
  3.     ,m_size(512)  
  4.     ,m_pLineSerie1(NULL)  


初始化两条曲线:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. CChartAxis *pAxis = NULL;   
  2. pAxis = m_chartCtrl.CreateStandardAxis(CChartCtrl::BottomAxis);  
  3. pAxis->SetAutomatic(true);  
  4. pAxis = m_chartCtrl.CreateStandardAxis(CChartCtrl::LeftAxis);  
  5. pAxis->SetAutomatic(true);  
  6. m_x.resize(m_size);  
  7. m_y.resize(m_size);  
  8. m_yploy.resize(m_size);  
  9. for(size_t i =0;i<m_size;++i)  
  10. {  
  11.     m_x[i] = i;  
  12.     m_y[i] = i+randf(-25,28);  
  13.     m_yploy[i] = 0.005*pow(double(i),2)+0.0012*i+4+randf(-25,25);  
  14. }  
  15. m_chartCtrl.RemoveAllSeries();//先清空  
  16. m_pLineSerie1 = m_chartCtrl.CreateLineSerie();    
  17. m_pLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序  
  18. m_pLineSerie1->AddPoints(&m_x[0], &m_y[0], m_size);  
  19. m_pLineSerie1->SetName(_T("线性数据"));  
  20. m_pLineSerie2 = m_chartCtrl.CreateLineSerie();    
  21. m_pLineSerie2->SetSeriesOrdering(poNoOrdering);//设置为无序  
  22. m_pLineSerie2->AddPoints(&m_x[0], &m_yploy[0], m_size);  
  23. m_pLineSerie2->SetName(_T("多项式数据"));  

rangf是随机数生成函数,实现如下:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. double ClineFitDlg::randf(double min,double max)  
  2. {  
  3.     int minInteger = (int)(min*10000);  
  4.     int maxInteger = (int)(max*10000);  
  5.     int randInteger = rand()*rand();  
  6.     int diffInteger = maxInteger - minInteger;  
  7.     int resultInteger = randInteger % diffInteger + minInteger;  
  8.     return resultInteger/10000.0;  
  9. }  

运行程序,如图所示


线性拟合的使用如下:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. void ClineFitDlg::OnBnClickedButton1()  
  2. {  
  3.     CString str,strTemp;  
  4.     czy::Fit fit;  
  5.     fit.linearFit(m_x,m_y);  
  6.     str.Format(_T("方程:y=%gx+%g\r\n误差:ssr:%g,sse=%g,rmse:%g,确定系数:%g"),fit.getSlope(),fit.getIntercept()  
  7.         ,fit.getSSR(),fit.getSSE(),fit.getRMSE(),fit.getR_square());  
  8.     GetDlgItemText(IDC_EDIT,strTemp);  
  9.     SetDlgItemText(IDC_EDIT,strTemp+_T("\r\n------------------------\r\n")+str);  
  10.     //在图上绘制拟合的曲线  
  11.     CChartLineSerie* pfitLineSerie1 = m_chartCtrl.CreateLineSerie();      
  12.     std::vector<double> x(2,0),y(2,0);  
  13.     x[0] = 0;x[1] = m_size-1;  
  14.     y[0] = fit.getY(x[0]);y[1] = fit.getY(x[1]);  
  15.     pfitLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序  
  16.     pfitLineSerie1->AddPoints(&x[0], &y[0], 2);  
  17.     pfitLineSerie1->SetName(_T("拟合方程"));//SetName的作用将在后面讲到  
  18.     pfitLineSerie1->SetWidth(2);  
  19. }  

需要如下步骤:

  • 声明Fit类,用于头文件在czy命名空间中,因此需要显示声明命名空间名称czy::Fit fit;
  • 把观察数据输入进行拟合,由于是线性拟合,可以使用LinearFit函数,此函数把观察量的x值和y值传入即可进行拟合
  • 拟合完后,拟合的相关结果保存在czy::Fit里面,可以通过相关方法调用,方法在头文件中都有详细说明

运行结果如图所示:



多项式拟合的使用如下:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. void ClineFitDlg::OnBnClickedButton2()  
  2. {  
  3.     CString str;  
  4.     GetDlgItemText(IDC_EDIT1,str);  
  5.     if (str.IsEmpty())  
  6.     {  
  7.         MessageBox(_T("请输入阶次"),_T("警告"));  
  8.         return;  
  9.     }  
  10.     int n = _ttoi(str);  
  11.     if (n<0)  
  12.     {  
  13.         MessageBox(_T("请输入大于1的阶数"),_T("警告"));  
  14.         return;  
  15.     }  
  16.     czy::Fit fit;  
  17.     fit.polyfit(m_x,m_yploy,n,true);  
  18.     CString strFun(_T("y=")),strTemp(_T(""));  
  19.     for (int i=0;i<fit.getFactorSize();++i)  
  20.     {  
  21.         if (0 == i)  
  22.         {  
  23.             strTemp.Format(_T("%g"),fit.getFactor(i));  
  24.         }  
  25.         else  
  26.         {  
  27.             double fac = fit.getFactor(i);  
  28.             if (fac<0)  
  29.             {  
  30.                 strTemp.Format(_T("%gx^%d"),fac,i);  
  31.             }  
  32.             else  
  33.             {  
  34.                 strTemp.Format(_T("+%gx^%d"),fac,i);  
  35.             }  
  36.         }  
  37.         strFun += strTemp;  
  38.     }  
  39.     str.Format(_T("方程:%s\r\n误差:ssr:%g,sse=%g,rmse:%g,确定系数:%g"),strFun  
  40.         ,fit.getSSR(),fit.getSSE(),fit.getRMSE(),fit.getR_square());  
  41.     GetDlgItemText(IDC_EDIT,strTemp);  
  42.     SetDlgItemText(IDC_EDIT,strTemp+_T("\r\n------------------------\r\n")+str);  
  43.     //绘制拟合后的多项式  
  44.     std::vector<double> yploy;  
  45.     fit.getFitedYs(yploy);  
  46.     CChartLineSerie* pfitLineSerie1 = m_chartCtrl.CreateLineSerie();      
  47.     pfitLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序  
  48.     pfitLineSerie1->AddPoints(&m_x[0], &yploy[0], yploy.size());  
  49.     pfitLineSerie1->SetName(_T("多项式拟合方程"));//SetName的作用将在后面讲到  
  50.     pfitLineSerie1->SetWidth(2);  
  51. }  

步骤如下:

  • 和线性拟合一样,声明Fit变量
  • 输入观察值,同时输入需要拟合的阶次,这里输入2阶,就是2项式拟合,最后的布尔变量是标定是否需要把拟合的结果点保存起来,保存点会根据观察的x值计算拟合的y值,保存结果点会花费更多的内存,如果拟合后需要绘制,设为true会更方便,如果只需要拟合的方程,可以设置为false
  • 拟合完后,拟合的相关结果保存在czy::Fit里面,可以通过相关方法调用,方法在头文件中都有详细说明
代码:
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. for (int i=0;i<fit.getFactorSize();++i)  
  2. {  
  3.     if (0 == i)  
  4.     {  
  5.         strTemp.Format(_T("%g"),fit.getFactor(i));  
  6.     }  
  7.     else  
  8.     {  
  9.         double fac = fit.getFactor(i);  
  10.         if (fac<0)  
  11.         {  
  12.             strTemp.Format(_T("%gx^%d"),fac,i);  
  13.         }  
  14.         else  
  15.         {  
  16.             strTemp.Format(_T("+%gx^%d"),fac,i);  
  17.         }  
  18.     }  
  19.     strFun += strTemp;  
  20. }  

是用于生成方程的,由于系数小于时,打印时会把负号“-”显示,而正数时却不会显示正号,因此需要进行判断,如果小于0就不用添加“+”号,如果大于0就添加“+”号
结果如下:



源代码下载:
C++最小二乘法拟合-(线性拟合和多项式拟合)
0 0
原创粉丝点击