信源编码作业——LMS算法

来源:互联网 发布:php静态计数器 编辑:程序博客网 时间:2024/05/22 05:20

一:LMS算法说明:

全称Least mean square 算法,又叫最小均方算法,用于修正滤波器参数使均方差(Mean Square Error,MSE)达到最小,均方差公式如下所示: 
                                    这里写图片描述

步骤:

(1)设置变量和参量: 
X(n)为输入向量,或称为训练样本 
W(n)为权值向量 
b(n)为偏差 
d(n)为期望输出 
y(n)为实际输出 
η为学习速率 
n为迭代次数 
(2)初始化,赋给w(0)各一个较小的随机非零值,令n=0 
(3)对于一组输入样本x(n)和对应的期望输出d,计算 
e(n)=d(n)-X^T(n)W(n) 
W(n+1)=W(n)+ηX(n)e(n) 
(4)判断是否满足条件,若满足算法结束,若否n增加1,转入第(3)步继续执行

二、算法实现——LMS算法的代码

  1. const unsigned int nTests   =4;  
  2. const unsigned int nInputs  =2;  
  3. const double rho =0.005;  
  4.    
  5. struct lms_testdata  
  6. {  
  7.     doubleinputs[nInputs];  
  8.     doubleoutput;  
  9. };  
  10.    
  11. double compute_output(constdouble * inputs,double* weights)  
  12. {  
  13.     double sum =0.0;  
  14.     for (int i = 0 ; i < nInputs; ++i)  
  15.     {  
  16.         sum += weights[i]*inputs[i];  
  17.     }  
  18.     //bias  
  19.     sum += weights[nInputs]*1.0;  
  20.     return sum;  
  21. }  
  22. //计算均方差  
  23. double caculate_mse(constlms_testdata * testdata,double * weights)  
  24. {  
  25.     double sum =0.0;  
  26.     for (int i = 0 ; i < nTests ; ++i)  
  27.     {  
  28.         sum += pow(testdata[i].output -compute_output(testdata[i].inputs,weights),2);  
  29.     }  
  30.     return sum/(double)nTests;  
  31. }  
  32. //对计算所得值,进行分类  
  33. int classify_output(doubleoutput)  
  34. {  
  35.     if(output> 0.0)  
  36.         return1;  
  37.     else  
  38.         return-1;  
  39. }  
  40. int _tmain(int argc,_TCHAR* argv[])  
  41. {  
  42.     lms_testdata testdata[nTests] = {  
  43.         {-1.0,-1.0, -1.0},  
  44.         {-1.0, 1.0, -1.0},  
  45.         { 1.0,-1.0, -1.0},  
  46.         { 1.0, 1.0,  1.0}  
  47.     };  
  48.     doubleweights[nInputs + 1] = {0.0};  
  49.     while(caculate_mse(testdata,weights)> 0.26)//计算均方差,如果大于给定值,算法继续  
  50.     {  
  51.         intiTest = rand()%nTests;//随机选择一组数据  
  52.         doubleoutput = compute_output(testdata[iTest].inputs,weights);  
  53.         doubleerr = testdata[iTest].output - output;  
  54.         //调整输入端的权值  
  55.         for (int i = 0 ; i < nInputs ; ++i)  
  56.         {  
  57.             weights[i] = weights[i] + rho * err* testdata[iTest].inputs[i];  
  58.         }  
  59.         weights[nInputs] = weights[nInputs] +rho * err;  
  60.         cout<<"mse:"<<caculate_mse(testdata,weights)<<endl;  
  61.     }  
  62.    
  63.     for(int w = 0 ; w < nInputs + 1 ; ++w)  
  64.     {  
  65.         cout<<"weight"<<w<<":"<<weights[w]<<endl;  
  66.     }  
  67.     cout<<"\n";  
  68.     for (int i = 0 ;i < nTests ; ++i)  
  69.     {  
  70.         cout<<"rightresult:êo"<<testdata[i].output<<"\t";  
  71.         cout<<"caculateresult:" << classify_output(compute_output(testdata[i].inputs,weights))<<endl;  
  72.     }  
  73.     //  
  74.     char temp ;  
  75.     cin>>temp;  
  76.     return 0;  
  77. }  



原创粉丝点击