矩阵快速幂

来源:互联网 发布:平面设计和美工哪个好 编辑:程序博客网 时间:2024/05/22 06:57

一、矩阵的基础知识

1.结合性 (AB)C=A(BC).

2.对加法的分配性 (A+B)C=AC+BCC(A+B=CA+CB 

3.对数乘的结合性 k(AB=kA)B =A(kB).

4.关于转置 (AB)'=B'A'

一个矩阵就是一个二维数组,为了方便声明多个矩阵,我们一般会将矩阵封装一个类或定义一个矩阵的结构体,我采用的是后者。

最特殊的矩阵应该就是单位矩阵e了,它的对角线的元素为1,非对角线元素为0。一个n*n的矩阵的0次幂就是单位矩阵。

若A为n×k矩阵,B为k×m矩阵,则它们的乘积AB(有时记做A·B)将是一个n×m矩阵。其乘积矩阵AB的第i行第j列的元素为:

一般矩阵乘法采用朴素的O(n^3)的算法,但是对于一些比较稀疏的矩阵(就是矩阵中0比较多),对于这样的矩阵我们可以采用矩阵的优化,这个算法也适用于一般的矩阵,0特别多时,复杂度可能会降低到O(n^2),实现如下:

还要注意的是,我们要尽可能的减少取模运算,因为取模的复杂度很高,这样我们就可以节约时间了。

矩阵加法就是简单地将对应的位置的两个矩阵的元素相加。

我们一般考虑的是n阶方阵之间的乘法以及n阶方阵与n维向量(把向量看成n×1的矩阵)的乘法。矩阵乘法最重要的性质就是满足结合律,同时它另一个很重要的性质就是不满足交换率,这保证了矩阵的幂运算满足快速幂取模(A^k % MOD)算法,矩阵快速幂其实就是二分指数,避免重复的计算。我们可以采用递归的方式很容易的写出来,但是当指数比较大,或者矩阵比较大得时候,我们就会出现栈溢出的状况,不断RE(我就被坑过)。所以还是写成迭代的方式比较好。

制作矩阵图一般要遵循以下几个步骤: 

1、列出质量因素: 

2、把成对对因素排列成行和列,表示其对应关系 

3、选择合适的矩阵图类型 

4、在成对因素交点处表示其关系程度,一般凭经验进行定性判断,可分为三种:关系密切、关系较密切、关系一般(或可能有关系),并用不同符号表示 

5、根据关系程度确定必须控制的重点因素 

6、针对重点因素作对策表。

二、矩阵快速幂的应用

7、poj3070 是求解菲波那切数列,f(n)=f(n-1)+f(n-2),如果我们一个个递推求解,当n特别大的时候复杂度就会变的很高,对于f(n)= a*f(n-1)+b*f(n-2),在矩阵运算中我们会发现这样一组公式:

到知道这个公式后我们就采用矩阵快速幂的方法可以求解f(n)

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <cstring>  
  4. using namespace std;  
  5. struct mat{  
  6.     int at[2][2];  
  7. };  
  8. mat d;  
  9. int n,mod;  
  10. mat mul(mat a,mat b)  
  11. {  
  12.     mat t;  
  13.     memset(t.at,0,sizeof(t.at));  
  14.     for(int i=0;i<n;++i)  
  15.     {  
  16.         for(int k=0;k<n;++k)  
  17.         {  
  18.             if(a.at[i][k])  
  19.             for(int j=0;j<n;++j)  
  20.             {  
  21.                 t.at[i][j]+=a.at[i][k]*b.at[k][j];  
  22.                 if(t.at[i][j]>=mod){t.at[i][j]%=mod;}  
  23.             }  
  24.         }  
  25.     }  
  26.     return t;  
  27. }  
  28. mat expo(mat p,int k)  
  29. {  
  30.     if(k==1)return p;  
  31.     mat e;  
  32.     memset(e.at,0,sizeof(e.at));  
  33.     for(int i=0;i<n;++i){e.at[i][i]=1;}  
  34.     if(k==0)return e;  
  35.     while(k)  
  36.     {  
  37.         if(k&1)e=mul(p,e);  
  38.         p=mul(p,p);  
  39.         k>>=1;  
  40.     }  
  41.     return e;  
  42. }  
  43. int main()  
  44. {  
  45.     n=2;mod=10000;  
  46.     d.at[1][1]=0;  
  47.     d.at[0][0]=d.at[1][0]=d.at[0][1]=1;  
  48.     int k;  
  49.     while(~scanf("%d",&k))  
  50.     {  
  51.         if(k==-1)break;  
  52.         mat ret=expo(d,k);  
  53.         int ans=ret.at[0][1]%mod;  
  54.         printf("%d\n",ans);  
  55.     }  
  56.     return 0;  
  57. }  


 

2poj3233题意:给出矩阵A,求S = A + A^2 + A^3 + … + A^k 二分和

 

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <cstring>  
  4. using namespace std;  
  5. #define LL long long  
  6. int n,m,k;   
  7. int MOD;  
  8. struct mat {  
  9.     int at[40][40];  
  10. };  
  11. mat d;  
  12. mat mul(mat a, mat b)   
  13. {  
  14.     mat ret;  
  15.     memset(ret.at,0,sizeof(ret.at));  
  16.     for (int i=0;i<n;++i)  
  17.     {  
  18.         for (int k=0;k<n;++k)   
  19.         {  
  20.             if(a.at[i][k])  
  21.             for (int j=0;j<n;++j)  
  22.             {  
  23.                 ret.at[i][j]+=a.at[i][k]*b.at[k][j];  
  24.                 if(ret.at[i][j]>=MOD){ret.at[i][j]%=MOD;}  
  25.             }  
  26.         }  
  27.     }  
  28.     return ret;  
  29. }  
  30.   
  31. mat expo(mat a, int k)   
  32. {  
  33.     if(k==1)return a;  
  34.     mat e;  
  35.     memset(e.at,0,sizeof(e.at));  
  36.     for(int i=0;i<n;++i){e.at[i][i]=1;}  
  37.     if(k==0)return e;  
  38.     while(k)  
  39.     {  
  40.         if(k&1)e=mul(a,e);  
  41.         a=mul(a,a);  
  42.         k>>=1;  
  43.     }  
  44.     return e;  
  45. }  
  46.   
  47. mat add(mat a,mat b)  
  48. {  
  49.     mat t;  
  50.     for(int i=0;i<n;++i)  
  51.     {  
  52.         for(int j=0;j<n;++j)  
  53.         {   
  54.             t.at[i][j]=(a.at[i][j]+b.at[i][j]);  
  55.             if(t.at[i][j]>=MOD){t.at[i][j]%=MOD;}  
  56.         }  
  57.     }  
  58.     return t;  
  59. }  
  60.   
  61. void print(mat ans)  
  62. {  
  63.     for(int i=0;i<n;++i)  
  64.     {  
  65.         for(int j=0;j<n;++j)  
  66.         {  
  67.             if(j==0){printf("%d",ans.at[i][j]);continue;}  
  68.             printf(" %d",ans.at[i][j]);  
  69.         }  
  70.         printf("\n");  
  71.     }  
  72. }  
  73.   
  74. mat sum(int k)  
  75. {  
  76.     if(k==1){return d;}  
  77.     if(k&1)  
  78.     {  
  79.         return add(sum(k-1),expo(d,k));  
  80.     }  
  81.     else  
  82.     {  
  83.         mat s=sum(k>>1);  
  84.         return add(s,mul(s,expo(d,k>>1)));  
  85.     }  
  86. }  
  87. int main()  
  88. {  
  89.     while(~scanf("%d%d%d",&n,&k,&m))  
  90.     {  
  91.         MOD=m;  
  92.         mat ans,t;  
  93.         for(int i=0;i<n;++i)  
  94.         {  
  95.             for(int j=0;j<n;++j)  
  96.             {  
  97.                 scanf("%d",&d.at[i][j]);  
  98.                 if(d.at[i][j]>=m)  
  99.                 {  
  100.                     d.at[i][j]%=m;  
  101.                 }  
  102.             }  
  103.         }  
  104.         ans=sum(k);  
  105.         print(ans);  
  106.     }  
  107.     return 0;   
  108. }  


3poj3735

   题意:有n只猫咪,开始时每只猫咪有花生0颗,现有一组操作,由下面三个中的k个操作组成:

   1. g i i只猫咪一颗花生米

   2. e i 让第i只猫咪吃掉它拥有的所有花生米

   3. s i j 将猫咪i与猫咪j的拥有的花生米交换

   现将上述一组操作做m次后,问每只猫咪有多少颗花生?

分析:刚开始每只猫都没有花生,所以我们要在单位矩阵上构建矩阵。给第i只猫一个花生米,那么++met[0][i],让第i只猫吃掉所有的花生米,就令第i列清空,喵咪i与猫咪j交换花生米,就令第i列和第j列互换。矩阵就这样构造完毕,操作m次,我们就可以矩阵快速幂计算了。

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstring>  
  3. #include <cstdio>  
  4. #define LL long long  
  5. using namespace std;  
  6. struct met{  
  7.     LL at[105][105];  
  8. };  
  9. met ret,d;  
  10. LL n,m,k;  
  11. met mul(met a,met b)  
  12. {  
  13.     memset(ret.at,0,sizeof(ret.at));  
  14.     for(int i=0;i<=n;++i)  
  15.     {  
  16.         for(int k=0;k<=n;++k)  
  17.         {  
  18.             if(a.at[i][k])  
  19.             {  
  20.                 for(int j=0;j<=n;++j)  
  21.                 {  
  22.                     ret.at[i][j]+=a.at[i][k]*b.at[k][j];  
  23.                 }  
  24.             }  
  25.         }  
  26.     }  
  27.     return ret;  
  28. }  
  29.   
  30. met expo(met a,LL k)  
  31. {  
  32.     if(k==1) return a;  
  33.     met e;  
  34.     memset(e.at,0,sizeof(e.at));  
  35.     for(int i=0;i<=n;++i){e.at[i][i]=1;}  
  36.     if(k==0)return e;  
  37.     while(k)  
  38.     {  
  39.         if(k&1)e=mul(e,a);  
  40.         k>>=1;  
  41.         a=mul(a,a);  
  42.     }  
  43.     return e;  
  44. }  
  45.   
  46.   
  47. int main()  
  48. {  
  49.     while(~scanf("%lld%lld%lld",&n,&m,&k))  
  50.     {  
  51.         LL a,b;  
  52.         char ch[5];  
  53.         if(!n&&!k&&!m)break;  
  54.         memset(d.at,0,sizeof(d.at));  
  55.         for(int i=0;i<=n;++i)  
  56.         {d.at[i][i]=1;}  
  57.         while(k--)  
  58.         {  
  59.             scanf("%s",ch);  
  60.             if(ch[0]=='g')  
  61.             {  
  62.                 scanf("%lld",&a);  
  63.                 d.at[0][a]++;         
  64.             }  
  65.             else if(ch[0]=='e')  
  66.             {  
  67.                 scanf("%lld",&a);  
  68.                 for(int i=0;i<=n;++i)  
  69.                 {  
  70.                     d.at[i][a]=0;     
  71.                 }  
  72.             }  
  73.             else {  
  74.                 scanf("%lld%lld",&a,&b);  
  75.                 for(int i=0;i<=n;++i)  
  76.                 {  
  77.                     LL t=d.at[i][a];  
  78.                     d.at[i][a]=d.at[i][b];  
  79.                     d.at[i][b]=t;  
  80.                 }  
  81.   
  82.             }  
  83.         }  
  84.         met ans=expo(d,m);  
  85.         printf("%lld",ans.at[0][1]);  
  86.         for(int i=2;i<=n;++i)  
  87.         {  
  88.             printf(" %lld",ans.at[0][i]);  
  89.         }  
  90.         printf("\n");  
  91.   
  92.     }  
  93.     return 0;   
  94. }  


 

4poj3150题目大意:给定n1<=n<=500)个数字和一个数字m,这n个数字组成一个环(a0,a1.....an-1)。如果对ai进行一次d-step操作,那么ai的值变为与ai的距离小于d的所有数字之和模m。求对此环进行Kd-stepK<=10000000)后这个环的数字会变为多少。

分析:首先我们要构造矩阵,我们会得到一个500*500的矩阵,那么代码的复杂度就会变成O(log(k)*n^3),很明显这么高的复杂度会超时的。但是我们发现这个矩阵是一个循环矩阵, 第i行都是第i-1行,右移一位得到的,即a[i][j]=a[i-1][j-1]。很容易我们就可以发现循环矩阵a和循环矩阵b的乘积矩阵cc[i][j]=sum(a[i][k]*b[k][j])=sum(a[i-1][k-1]*b[j-1][k-1])=c[i-1][j-1]。那么矩阵c也是一个循环矩阵,在做矩阵乘法的时候我们只需要算出第一行的值,其余行直接右移就可以得到,那么算法的复杂度就会变为O(log(k)*n^2)。还需注意的是对于数据范围会超int,要用long long,还有由于矩阵太大了,在函数中申请不了那么大得空间,所以采用指针的方法去写函数。

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <cstring>  
  4. #define LL long long  
  5. using namespace std;  
  6. const int maxn=502;  
  7. int n,m,d,k;  
  8. LL tmp[maxn][maxn],e[maxn][maxn],c[maxn][maxn];  
  9. void mul(LL a[][maxn],LL b[][maxn])  
  10. {  
  11.     memset(c,0,sizeof(c));  
  12.     for(int k=0;k<n;++k)  
  13.     {  
  14.         if(a[0][k])  
  15.         for(int j=0;j<n;++j)  
  16.         {  
  17.             c[0][j]+=a[0][k]*b[k][j];  
  18.             if(c[0][j]>=m){c[0][j]%=m;}  
  19.         }  
  20.     }  
  21.     for(int i=1;i<n;++i)  
  22.     {  
  23.         for(int j=0;j<n;++j)   
  24.         {  
  25.             c[i][j]=c[i-1][(j-1+n)%n];  
  26.         }  
  27.     }  
  28.     for(int i=0;i<n;++i)  
  29.     {  
  30.         for(int j=0;j<n;++j)  
  31.         {  
  32.             b[i][j]=c[i][j];  
  33.         }  
  34.     }  
  35. }  
  36.   
  37. void expo(LL a[][maxn],int k)  
  38. {  
  39.     if(k==1){  
  40.         for(int i=0;i<n;++i)  
  41.         {  
  42.             for(int j=0;j<n;++j)  
  43.             {  
  44.                 e[i][j]=a[i][j];  
  45.             }  
  46.         }  
  47.         return;  
  48.     }  
  49.     memset(e,0,sizeof(e));  
  50.     for(int i=0;i<n;++i){e[i][i]=1;}  
  51.     while(k)  
  52.     {  
  53.         if(k&1){mul(a,e);}  
  54.         mul(a,a);  
  55.         k>>=1;  
  56.     }  
  57. }  
  58. int main()  
  59. {  
  60.     LL dat[maxn];  
  61.     scanf("%d%d%d%d",&n,&m,&d,&k);  
  62.     for(int i=0;i<n;++i)  
  63.     {  
  64.         scanf("%lld",&dat[i]);  
  65.         tmp[0][i]=0;  
  66.     }  
  67.     tmp[0][0]=1;  
  68.     for(int i=1;i<=d;++i)  
  69.     {  
  70.         tmp[0][i]=tmp[0][n-i]=1;  
  71.     }  
  72.     for(int i=1;i<n;++i)  
  73.     {  
  74.         for(int j=0;j<n;++j)  
  75.         {  
  76.             tmp[i][j]=tmp[i-1][(j-1+n)%n];  
  77.         }  
  78.     }  
  79.     expo(tmp,k);  
  80.     LL ans[maxn];  
  81.     memset(ans,0,sizeof(ans));  
  82.     for(int i=0;i<n;++i)  
  83.     {  
  84.         for(int j=0;j<n;++j)  
  85.         {  
  86.             ans[i]+=e[i][j]*dat[j];  
  87.             if(ans[i]>=m){ans[i]%=m;}  
  88.         }  
  89.     }  
  90.     printf("%lld",ans[0]);  
  91.     for(int i=1;i<n;++i)  
  92.     {  
  93.         printf(" %lld",ans[i]);  
  94.     }  
  95.     printf("\n");  
  96.     return 0;  
  97. }  


 

对于这道题,网上还有一段神代码

[cpp] view plain copy
  1. #include <iostream>  
  2. #include <cstdio>  
  3. #include <cstring>  
  4. #define LL long long  
  5. using namespace std;  
  6. int n,m,d,k;  
  7. void mul(LL a[],LL b[])  
  8. {  
  9.       int i,j;  
  10.       LL c[501];  
  11.       for(i=0;i<n;++i)for(c[i]=j=0;j<n;++j)c[i]+=a[j]*b[i>=j?(i-j):(n+i-j)];  
  12.       for(i=0;i<n;b[i]=c[i++]%m);                       
  13. }  
  14. LL init[501],tmp[501];  
  15. int main()  
  16. {  
  17.     int i,j;  
  18.     scanf("%d%d%d%d",&n,&m,&d,&k);  
  19.     for(i=0;i<n;++i)scanf("%lld",&init[i]);  
  20.     for(tmp[0]=i=1;i<=d;++i)tmp[i]=tmp[n-i]=1;  
  21.     while(k)  
  22.     {  
  23.             if(k&1)mul(tmp,init);  
  24.             mul(tmp,tmp);  
  25.             k>>=1;       
  26.     }  
  27.     for(i=0;i<n;++i)if(i)printf(" %lld",init[i]);else printf("%lld",init[i]);  
  28.     printf("\n");  
  29.     return 0;  
  30. }  

1 0
原创粉丝点击