直线拟合、二次曲线拟合、折线拟合和KNN近邻(附代码)

来源:互联网 发布:电信宽带测速软件 编辑:程序博客网 时间:2024/05/21 06:38

在一个工程应用中,需要对一组数据分别做上述四种形式的拟合回归,并根据拟合得到的模型对新的输入做预测(evaluation),本质上就是函数曲线拟合。

下面的 RevPre.h / RevPre.cpp 定义了数据结构和方法,Util.cpp 是使用示例;调用 Opt 时传入的 type 参数指定要拟合的是哪一种模型。


1)直线拟合: y = kx+b

2)   二次曲线拟合:y = AX^2 + BX + C

以上两种很典型,不多解释;

3) KNN

KNN的Opt阶段什么也没做,只是保存了所有数据对;预测阶段选取与输入最近的K个点(K取样本数的10%,并限制在 MIN_KNN_K ≤ K ≤ MAX_KNN_K 之间),以距离的倒数为权重对这些点的 y 值做加权平均。

4) 折线拟合:

前面先介绍KNN,是因为折线拟合的结果与之很类似,同样保存了一系列的数据对,检测的时候判断处于哪一段。不同之处是做了简化,而不是全部存储起来。

首先以直线拟合判断数据总体是上升还是下降,并把数据对按照x排序;接下来不断淘汰不符合该顺序的点:发生“突起”或者“凹陷”时,在相邻的两个点中清除与拟合直线偏差更大的那个,直到剩余的点都按序排列,形成一条单调折线。


代码:

RevPre.h

// RevPre.h - curve-fitting models: straight line, quadratic curve,
// monotone broken line and K-nearest-neighbour regression.
#ifndef REVPRE_H
#define REVPRE_H

#include <iostream>
#include <math.h>

/************************ Model type *********************************/

// Number of polynomial coefficients stored in a Model (a0 + a1*x + a2*x^2).
#define MAX_POL_DEPTH 3
// Inclusive bounds for the neighbour count used by the KNN model.
#define MAX_KNN_K 10
#define MIN_KNN_K 1

// The whole expansion is parenthesized so the macros compose safely inside
// larger expressions (e.g. 2 * MAX(a, b)). Arguments are evaluated twice,
// so do not pass expressions with side effects.
#ifndef MAX
#define MAX(x,y) ((x) > (y) ? (x) : (y))
#endif
#ifndef MIN
#define MIN(x,y) ((x) < (y) ? (x) : (y))
#endif

// Kind of model held by a Model instance.
enum modelType
{
    StraightLine = 0, // default
    CurveAt2,
    BrokenLine,
    KNNModel
};

struct Model
{
    enum modelType type;

    // Line parameters: lineParam[i] is the coefficient of x^i.
    double lineParam[MAX_POL_DEPTH];

    // Point model: len (x, y) samples stored in px / py.
    double *px, *py;
    int len;
};

// Allocates an empty model (type StraightLine, no samples).
Model* CreateModel();
// Frees the model and its sample arrays and sets *_ptr to NULL.
void ReleaseModel( Model** _ptr );
// Copies len samples from x / y into the model.
bool SetOptData( Model* ptr, double *x, double *y, int len );
// Fits the model of the requested type to the stored samples.
bool Opt( Model *ptr, modelType type );
// Evaluates the fitted model at x.
double Predict( Model *ptr, double x );

/************************** Polynomial *******************************/
/* Internal */
// Stores x[j]^i for every sample j and power i < maxDepth in
// powers[i*ptNum + j].
void CalculatePower( double *powers, int ptNum, int maxDepth, double *x );
// Builds the augmented coefficient matrix of the normal equations.
void CalculateParams( double *powers, int ptNum, int maxDepth, double *params, double *y );
// Solves the augmented system by LU decomposition with column pivoting.
void DirectLU( double *params, int ptNum, int maxDepth, double *x );
// Exchanges the values of two variables.
inline void swap( double &, double & );
/* External */
bool PolynomialOpt( Model *ptr );

/************************ StraightLine ********************************/
bool StraightLineOpt( Model *ptr );

/************************ BrokenLine **********************************/
/* Internal */
int SingleSort( double *index, double *context, int start, int end );
void QuickSort( double *index, double *context, int start, int end );
int CheckSequence( double *context, int start, int end, bool upTrend );
/* External */
bool BrokenLineOpt( Model *ptr );

/******************** KNN (Lazy-learning) *****************************/
bool KNNOpt( Model *ptr );

#endif // REVPRE_H

RevPre.cpp

#include "Revise.h"Model* CreateModel(){Model *ptr = new Model;ptr->type = StraightLine;ptr->px = ptr->py = NULL;ptr->len = 0;memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );return ptr;}void ReleaseModel( Model** _ptr ){Model *ptr = *_ptr;if ( ptr->px ) delete[] ptr->px;if ( ptr->py ) delete[] ptr->py;delete ptr;*_ptr = NULL;return ;}bool SetOptData( Model *ptr, double *x, double *y, int len ){if ( !ptr || !x || !y ) return false;if ( !ptr->px ) ptr->px = new double[len];if ( !ptr->py ) ptr->py = new double[len];ptr->len = len;memcpy( ptr->px, x, sizeof(double)*len );memcpy( ptr->py, y, sizeof(double)*len );return true;}bool Opt( Model *ptr, modelType type ){if ( !ptr ) return false;switch( type ){case StraightLine:returnStraightLineOpt( ptr );case CurveAt2:returnPolynomialOpt( ptr );case BrokenLine:returnBrokenLineOpt( ptr );case KNNModel:returnKNNOpt( ptr );default:returnfalse;}}double Predict( Model *ptr, double x ){if ( !ptr ) exit (-1);switch( ptr->type ){case StraightLine:return ptr->lineParam[0] + ptr->lineParam[1]*x;case CurveAt2:return ptr->lineParam[0] + ptr->lineParam[1]*x + ptr->lineParam[2]*x*x;case BrokenLine:{if ( ptr->len < 3 ) exit(-2);int first = 0;if ( x <= ptr->px[0] ){double x0 = ptr->px[0], x1 = ptr->px[1];double y0 = ptr->py[0], y1 = ptr->py[1];return y0 - (x0-x)*(y1-y0)/(x1-x0);}else if ( x >= ptr->px[ptr->len-1] ){double x0 = ptr->px[ptr->len-2], y0 = ptr->py[ptr->len - 2];double x1 = ptr->px[ptr->len-1], y1 = ptr->py[ptr->len - 1];return y1 -(x-x1)*(y0-y1)/(x1-x0);}else{while ( ptr->px[first] < x ) { first ++ ;}first --;double deltay = ptr->py[first+1] - ptr->py[first];double deltax = ptr->px[first+1] - ptr->px[first];return ptr->py[first] + deltay*(x-ptr->px[first])/deltax;}}case KNNModel:{int K = MAX( MIN_KNN_K, MIN( int(ptr->len*0.1), MAX_KNN_K ) );// Prepare the initial K neighboursdouble *dist_team = new double[K];int*idx_team = new int[K];intfarestIdt = -1;doublefarestDist = 0;int id = 0;for ( ; id < K; id ++ ){idx_team[id] 
= id;dist_team[id] = abs( ptr->px[id] - x );if ( farestDist <= dist_team[id] ){farestIdt = id;farestDist = dist_team[id];}}// Looking for the K nearest neighbourswhile ( id < ptr->len ){if ( abs( ptr->px[id] -x ) < farestDist ){// Update the teamidx_team[farestIdt] = id;dist_team[farestIdt] = abs( ptr->px[id] - x );// Update the farest recordfarestIdt = 0;farestDist = dist_team[0];for ( int searchIdt = 1; searchIdt < K; searchIdt ++ ){if ( dist_team[searchIdt] > farestDist ){farestDist = dist_team[searchIdt];farestIdt = searchIdt;}}}id ++;}// Calculate their contributiondouble res = 0.0;double weightSum = 0.0;for ( int seachIdt = 0; seachIdt < K; seachIdt ++ ){weightSum += 1.0/dist_team[seachIdt];res += 1.0/dist_team[seachIdt]*ptr->py[idx_team[seachIdt]];}delete[] dist_team;delete[] idx_team;return res/weightSum;}default:exit(-2);}}/**************************Polynomial*******************************/bool StraightLineOpt( Model *ptr ){if ( !ptr ) return false; if ( !ptr->px || !ptr->py ) return false;int outLen = 2;int ptNum = ptr->len, maxDepth = outLen;double *powers = new double[maxDepth*ptNum];double *params = new double[maxDepth*(maxDepth+1)];CalculatePower( powers, ptNum, maxDepth, ptr->px );CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); //计算正规方程组的系数矩阵DirectLU( params, ptNum, maxDepth, ptr->lineParam ); //列主元LU分解ptr->type = StraightLine;std::cout<<"-------------------------"<<std::endl;std::cout<<"拟合函数的系数分别为:\n";for( int i=0;i<maxDepth;i++)std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;std::cout<<"-------------------------"<<std::endl;delete[] powers;delete[] params;return true;}bool PolynomialOpt( Model *ptr ){if ( !ptr ) return false; if ( !ptr->px || !ptr->py ) return false;int outLen = MAX_POL_DEPTH;int ptNum = ptr->len, maxDepth = outLen;double *powers = new double[maxDepth*ptNum];double *params = new double[maxDepth*(maxDepth+1)];CalculatePower( powers, ptNum, maxDepth, ptr->px );CalculateParams( powers, ptNum, maxDepth, params, 
ptr->py ); //计算正规方程组的系数矩阵DirectLU( params, ptNum, maxDepth, ptr->lineParam ); //列主元LU分解ptr->type = CurveAt2;/*std::cout<<"-------------------------"<<std::endl;std::cout<<"拟合函数的系数分别为:\n";for( int i=0;i<maxDepth;i++)std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;std::cout<<"-------------------------"<<std::endl;*/delete[] powers;delete[] params;return true;}void CalculatePower(double *powers, int ptNum, int maxDepth, double *x ){if ( !powers || !x ) return ;inti, j, k;doubletemp;for( i = 0; i < maxDepth; i ++ )for( j = 0; j < ptNum; j ++ ){temp = 1;for( k = 0; k < i; k ++ )temp *= x[j];powers[i*ptNum+j] = temp;}return ;}void CalculateParams(double *powers, int ptNum, int maxDepth,  double *params, double *y){if ( !powers || !params || !y ) return ;inti, j, k;doubletemp;intstep = maxDepth + 1;for( i = 0; i < maxDepth; i ++ ){for(j = 0; j < maxDepth; j ++ ){temp = 0;for( k = 0; k < ptNum; k ++ )temp += powers[i*ptNum+k]*powers[j*ptNum+k];params[i*step+j] = temp;}temp = 0;for( k = 0; k < ptNum; k ++ ){temp += y[k]*powers[i*ptNum+k];params[i*step+maxDepth] = temp;}}return ;}inline void swap(double &a,double &b){a=a+b;b=a-b;a=a-b;}void DirectLU( double *params, int ptNum, int maxDepth, double *x ){inti, r, k, j;doublemax;intstep = maxDepth + 1;double *s = new double[maxDepth];double *t = new double[maxDepth];// choose the main elementfor( r = 0; r < maxDepth; r ++ ){max = 0;j = r;for( i = r; i < maxDepth; i ++ ) {s[i] = params[i*step+r];for( k = 0; k < r; k ++ )s[i] -= params[i*step+k] * params[k*step+r];s[i] = abs(s[i]);if( s[i] > max ){j = i;max = s[i];}}// if the "main"element is not @ row r, swap the corresponding element if( j != r ) {for( i = 0; i < maxDepth + 1; i ++ )swap( params[r*step+i], params[j*step+i] );}for( i = r; i < step; i ++ ) for( k = 0; k < r; k ++ ){params[r*step+i] -= params[r*step+k] * params[k*step+i];}for(i = r+1; i < maxDepth; i ++ ) {for ( k = 0; k < r; k ++ )params[i*step+r] -= params[i*step+k] * params[k*step+r];params[i*step+r] /= 
params[r*step+r];}}for( i = 0; i < maxDepth; i ++ )t[i] = params[ i*step + maxDepth ];for ( i = maxDepth - 1; i >= 0; i -- ) //利用回代法求最终解{for ( r = maxDepth - 1; r > i; r -- )t[i] -= params[ i*step + r ] * x[r];x[i] = t[i]/params[i*step+i];}delete[] s;delete[] t;return ;}/**********************Broken Line***************************/// Quick Sortint SingleSort( double *index, double *context, int start, int end ){if ( end - start < 1 ) return start;int i = start, j = end;double key = index[i];double key_ = context[i];while ( i < j ){while ( index[j] > key && j > i ) j --;if ( index[j] < key ){index[i] = index[j];context[i] = context[j];}while ( index[i] < key && j > i ) i ++;if ( index[i] > key ){index[j] = index[i];context[j] = context[i];}}index[i] = key;context[j] = key_;return i;}void QuickSort( double *index, double *context, int start, int end ){if ( end - start < 1 ) return ; // importantint mid = SingleSort( index, context, start, end );QuickSort( index, context, start, mid - 1 );QuickSort( index, context, mid+ 1, end );}int CheckSequence( double *context, int start, int end, bool upTrend ){int i = start;for ( ; i < end; i ++ ){if ( upTrend && context[i+1] < context[i] ){return i;}if ( !upTrend && context[i] < context[i+1] ){return i;}}return -1;}// Form the broken linebool BrokenLineOpt( Model *ptr ){if ( !ptr ) return false;if ( !ptr->len || !ptr->px || !ptr->py ) return false;// analyse the trend of points and get its approximate lineStraightLineOpt( ptr );double k = ptr->lineParam[1], b = ptr->lineParam[0];bool upTrend = ( k > 0 );// sort the sequence by pyQuickSort( ptr->px, ptr->py, 0, ptr->len - 1 );int oddPoint = 0;while ( (oddPoint = CheckSequence( ptr->py, oddPoint, ptr->len -1, upTrend ) ) != -1 ){double formerErr = abs( k*ptr->px[oddPoint] + b - ptr->py[oddPoint] );double laterErr = abs( k*ptr->px[oddPoint+1] + b - ptr->py[oddPoint+1] );oddPoint = formerErr > laterErr ? 
oddPoint : oddPoint + 1;// remove the odd pointmemcpy( ptr->py + oddPoint, ptr->py + oddPoint + 1, sizeof(double) );memcpy( ptr->px + oddPoint, ptr->px + oddPoint + 1, sizeof(double) );ptr->len --;oddPoint --;}memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );ptr->type = BrokenLine;return true;}/**********************Lazy Learning***************************/bool KNNOpt( Model *ptr ){// We do nothing as we say it's a lazy-learning method// Only when predict() is called, the learning process is invokedmemset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );ptr->type = KNNModel;return true;}

Util.cpp

#include "Revise.h"int _tmain(int argc, _TCHAR* argv[]){double x[10], y[10];for ( int i = 0; i < 10; i ++ ){x[i] = 10 - i;y[i] = (i-2)*(i-2);}Model *model = CreateModel();SetOptData( model, x, y, 10 );Opt( model, BrokenLine );double result = Predict( model, 1.5 );ReleaseModel( &model );return 0;}


0 0
原创粉丝点击