Kalman滤波(三)

来源:互联网 发布:淘宝隐形眼镜可靠吗 编辑:程序博客网 时间:2024/06/07 19:11

今天研究了一下卡尔曼滤波跟踪,同时也看了一下卡尔曼滤波Opencv的源代码,总是看懂了。下面是opencv自带的一个程序,代码如下:

[cpp] view plain copy
  1. // kalman.cpp : 定义控制台应用程序的入口点。  
  2. //  
  3.   
  4. #include "stdafx.h"  
  5.   
  6.   
  7. #include "opencv2/video/tracking.hpp"  
  8. #include "opencv2/highgui/highgui.hpp"  
  9.   
  10. #include <stdio.h>  
  11.   
  12. using namespace cv;  
  13.   
  14. static inline Point calcPoint(Point2f center, double R, double angle)  
  15. {  
  16.     return center + Point2f((float)cos(angle), (float)-sin(angle))*(float)R;  
  17. }  
  18.   
  19. static void help()  
  20. {  
  21.     printf( "\nExamle of c calls to OpenCV's Kalman filter.\n"  
  22. "   Tracking of rotating point.\n"  
  23. "   Rotation speed is constant.\n"  
  24. "   Both state and measurements vectors are 1D (a point angle),\n"  
  25. "   Measurement is the real point angle + gaussian noise.\n"  
  26. "   The real and the estimated points are connected with yellow line segment,\n"  
  27. "   the real and the measured points are connected with red line segment.\n"  
  28. "   (if Kalman filter works correctly,\n"  
  29. "    the yellow segment should be shorter than the red one).\n"  
  30.             "\n"  
  31. "   Pressing any key (except ESC) will reset the tracking with a different speed.\n"  
  32. "   Pressing ESC will stop the program.\n"  
  33.             );  
  34. }  
  35.   
  36. int main(intchar**)  
  37. {  
  38.     help();  
  39.     Mat img(500, 500, CV_8UC3);  
  40.     KalmanFilter KF(2, 1, 0);  
  41.     //[x1,x2]=[角度,角速度]  
  42.     /* 
  43.     运动模型:x1(k+1) = x1(k) + x2(k)*T 
  44.              x2(k+1) = x2(k) 
  45.     状态转移方程: 
  46.     x^ = AX + w 
  47.     测量方程: 
  48.     z = Hx + v 
  49.     */  
  50.     //状态估计值x=[x1,x2] --> state  
  51.     Mat state(2, 1, CV_32F); /* (phi, delta_phi) */  
  52.     Mat processNoise(2, 1, CV_32F);  
  53.     //当前观测值Z=Hx + v ---> measurement  
  54.     Mat measurement = Mat::zeros(1, 1, CV_32F);  
  55.     char code = (char)-1;  
  56.   
  57.     for(;;)  
  58.     {  
  59.         randn( state, Scalar::all(0), Scalar::all(0.1) );  
  60.         //transitionMatrix对应到状态转移方程中的矩阵A  
  61.         KF.transitionMatrix = *(Mat_<float>(2, 2) << 1, 1, 0, 1);  
  62.         //measurementMatrix对应到测量方程的矩阵H  
  63.         setIdentity(KF.measurementMatrix);  
  64.         //processNoiseCov对应过程噪声协方差Q  
  65.         setIdentity(KF.processNoiseCov, Scalar::all(1e-5));  
  66.         //measurementNoiseCov对一个测量噪声协方差R  
  67.         setIdentity(KF.measurementNoiseCov, Scalar::all(1e-1));  
  68.         //errorCovPost对应最优值对应的偏差P(k|k)  
  69.         setIdentity(KF.errorCovPost, Scalar::all(1));  
  70.         //statePost对应系统状态最优值x(k|k)  
  71.         randn(KF.statePost, Scalar::all(0), Scalar::all(0.1));  
  72.   
  73.         for(;;)  
  74.         {  
  75.             Point2f center(img.cols*0.5f, img.rows*0.5f);  
  76.             float R = img.cols/3.f;  
  77.             //角度  
  78.             double stateAngle = state.at<float>(0);  
  79.             Point statePt = calcPoint(center, R, stateAngle);  
  80.   
  81.             //X(k|k-1) = A*X(k-1|k-1)  
  82.             Mat prediction = KF.predict();  
  83.             //角度  
  84.             double predictAngle = prediction.at<float>(0);  
  85.             Point predictPt = calcPoint(center, R, predictAngle);  
  86.   
  87.             randn( measurement, Scalar::all(0), Scalar::all(KF.measurementNoiseCov.at<float>(0)));  
  88.   
  89.             // generate measurement  
  90.             //Z(k) = H*X(k)  
  91.             measurement += KF.measurementMatrix*state;  
  92.   
  93.             //角度  
  94.             double measAngle = measurement.at<float>(0);  
  95.             Point measPt = calcPoint(center, R, measAngle);  
  96.   
  97.             // plot points  
  98.             #define drawCross( center, color, d )                                 \  
  99.                 line( img, Point( center.x - d, center.y - d ),                \  
  100.                              Point( center.x + d, center.y + d ), color, 1, CV_AA, 0); \  
  101.                 line( img, Point( center.x + d, center.y - d ),                \  
  102.                              Point( center.x - d, center.y + d ), color, 1, CV_AA, 0 )  
  103.   
  104.             img = Scalar::all(0);  
  105.             drawCross( statePt, Scalar(255,255,255), 3 );  
  106.             drawCross( measPt, Scalar(0,0,255), 3 );  
  107.             drawCross( predictPt, Scalar(0,255,0), 3 );  
  108.             line( img, statePt, measPt, Scalar(0,0,255), 3, CV_AA, 0 );  
  109.             line( img, statePt, predictPt, Scalar(0,255,255), 3, CV_AA, 0 );  
  110.   
  111.             if(theRNG().uniform(0,4) != 0)  
  112.             //X(k|k) = X(k|k-1) + Kg(k)*[Z(k) - H*X(k|k-1)  
  113.                 KF.correct(measurement);  
  114.   
  115.             randn( processNoise, Scalar(0), Scalar::all(sqrt(KF.processNoiseCov.at<float>(0, 0))));  
  116.             //X(k) = AX(k-1) + W(k)  
  117.             state = KF.transitionMatrix*state + processNoise;  
  118.   
  119.             imshow( "Kalman", img );  
  120.             code = (char)waitKey(100);  
  121.   
  122.             if( code > 0 )  
  123.                 break;  
  124.         }  
  125.         if( code == 27 || code == 'q' || code == 'Q' )  
  126.             break;  
  127.     }  
  128.   
  129.     return 0;  
  130. }  

同时为了更好的理解代码,我们需要知道一下的东西

代码1.

[cpp] view plain copy
  1. Mat statePre;           //!< predicted state (x'(k)): x(k)=A*x(k-1)+B*u(k)  
  2.    Mat statePost;          //!< corrected state (x(k)): x(k)=x'(k)+K(k)*(z(k)-H*x'(k))  
  3.    Mat transitionMatrix;   //!< state transition matrix (A)  
  4.    Mat controlMatrix;      //!< control matrix (B) (not used if there is no control)  
  5.    Mat measurementMatrix;  //!< measurement matrix (H)  
  6.    Mat processNoiseCov;    //!< process noise covariance matrix (Q)  
  7.    Mat measurementNoiseCov;//!< measurement noise covariance matrix (R)  
  8.    Mat errorCovPre;        //!< priori error estimate covariance matrix (P'(k)): P'(k)=A*P(k-1)*At + Q)*/  
  9.    Mat gain;               //!< Kalman gain matrix (K(k)): K(k)=P'(k)*Ht*inv(H*P'(k)*Ht+R)  
  10.    Mat errorCovPost;       //!< posteriori error estimate covariance matrix (P(k)): P(k)=(I-K(k)*H)*P'(k)  

一看上面的注释大概也明白什么意思了。

此外,还需要看2断代码

[cpp] view plain copy
  1. const Mat& KalmanFilter::predict(const Mat& control)  
  2. {  
  3.     // update the state: x'(k) = A*x(k)  
  4.     statePre = transitionMatrix*statePost;  
  5.   
  6.     if( control.data )  
  7.         // x'(k) = x'(k) + B*u(k)  
  8.         statePre += controlMatrix*control;  
  9.   
  10.     // update error covariance matrices: temp1 = A*P(k)  
  11.     temp1 = transitionMatrix*errorCovPost;  
  12.   
  13.     // P'(k) = temp1*At + Q  
  14.     gemm(temp1, transitionMatrix, 1, processNoiseCov, 1, errorCovPre, GEMM_2_T);  
  15.   
  16.     // handle the case when there will be measurement before the next predict.  
  17.     statePre.copyTo(statePost);  
  18.   
  19.     return statePre;  
  20. }  

这个段代码其实就是:X(k|k-1) = A(k-1|k-1) ,同时得到预测结果X(k|k-1)的偏差P(k|k-1)

[cpp] view plain copy
  1. const Mat& KalmanFilter::correct(const Mat& measurement)  
  2. {  
  3.     // temp2 = H*P'(k)  
  4.     temp2 = measurementMatrix * errorCovPre;  
  5.   
  6.     // temp3 = temp2*Ht + R  
  7.     gemm(temp2, measurementMatrix, 1, measurementNoiseCov, 1, temp3, GEMM_2_T);  
  8.   
  9.     // temp4 = inv(temp3)*temp2 = Kt(k)  
  10.     solve(temp3, temp2, temp4, DECOMP_SVD);  
  11.   
  12.     // K(k):卡尔曼增益  
  13.     gain = temp4.t();  
  14.   
  15.     // temp5 = z(k) - H*x'(k)  
  16.     temp5 = measurement - measurementMatrix*statePre;  
  17.   
  18.     // x(k) = x'(k) + K(k)*temp5  
  19.     statePost = statePre + gain*temp5;  
  20.   
  21.     // P(k) = P'(k) - K(k)*temp2  
  22.     errorCovPost = errorCovPre - gain*temp2;  
  23.   
  24.     return statePost;  
  25. }  

上面的代码其实就是:求Kg(k)=P(k|k-1)H' / (HP(k|k-1)H' + R) ,X(k|k) = X(k|k-1) + Kg(k)(Z(k) - HX(k|k-1),P(k|k) = ( 1 - Kg(k)H)P(k|k-1)


原文:http://blog.csdn.net/suky520/article/details/20745479

原创粉丝点击