快速幂或矩阵快速幂

来源:互联网 发布:高频注入源码 编辑:程序博客网 时间:2024/05/29 13:31


原文链接


快速幂或者矩阵快速幂在算大指数次方时是很高效的,他的基本原理是二进制,下面的A可以是一个数也可以是一个矩阵(本文特指方阵),若是数就是快速幂算法,若是矩阵就是矩阵快速幂算法,用c++只需把矩阵设成一个类就可以,然后重载一下乘法就可以,注意为矩阵时则ANS=1,应该是ANS=E,E是单位矩阵,即主对角线是1其余的部分都是0的特殊方阵了。

 举个例子若你要算A^7你会怎么算一般你会用O(N)的算法A^7=A*A*A*A*A*A*A也许你觉得这并不慢但是若要你算A^10000000000000000呢,是不是会觉得O(N)的算法也太慢了吧这不得算死我啊,计算机也不想算了,因为有更高效的算法我们把A的指数写成二进制,这样就有了

A^7=A^111(2)  现在我们可以这么算 令ANS=1;MULTI=A  ,N=7

while(N)

{

if(N%2==1) //亦可以写成N&1 或N%2

ANS*=MULTI;

MULTI*=MULTI;

N/=2;//c++中可以写成 N>>=1;直接用位运算更快

}

写出上面的代码的执行过程就是

ANS=1;MULTI=A; 

N=7 ;N%2=1;   ANS*=MULTI; 所以ANS=A;  MULTI*=MULTI; 所以MULTI=A^2

然后 N/=2;N=3; N%2=1; ANS*=MULTI; 所以 ANS=A*A^2=A^3 ; 又MULTI*=MULTI; 所以MULTI=A^4

然后N/=2;N=1;N%2=1;ANS*=MULTI; 所以 ANS=A*A^2*A^4=A^7;又MULTI*=MULTI; 所以MULTI=A^8

然后N/=2;N=0;算法结束  是不是很巧妙呢,实际上用的乘法次数是 6次你可能觉得,那个A^7=A*A*A*A*A*A*A,不也是用了6次乘法吗有什么区别?

那是因为这个算法是log2(n)   (表示以2为底n的对数) 的复杂度,还有一个系数,大约是2 实际上计算次数就是 2*log2(n) 而普通的连乘计算的复杂度是n 乘法计算次数是n-1

这样在n很小时差别不大,但随着n的增长差距会迅速扩大,例如 n=1024时 普通方法得计算1023次乘法,但快速幂最多(因为当上面的程序执行时N的中间结果为偶数那么  ANS*=MULTI,将不会被执行,故实际的计算次数要小于 2*log2(n))只算2*log2(n) =20次乘法,是不是很快!!!!!!!!!!

但是为什么呢?好像还有点不懂。。

实际上A^7=A^1*A^2*A^4这样每次计算乘法乘的因子都是递增的,而且还是指数递增,还有这些因子是可以递推产生的就是可以利用上次的计算每次平方就可以了,这中其实是使用的二进制的思想,因为任意一个数都可以,表示成二进制,故 A^N以定可以写成

A^(一个二进制数如101010)=A^(100000)*A^(00000)*A(1000)*A^(000)*A^(10)*A^(0)=A^(2^5)*A^(2^3)*A^(2^1)

而我们的MULTI 其实是一个数列 A^1,A^2,A^4,A^8,A^16........即A^(2^0),A^(2^1),A^(2^2),A^(2^3),A^(2^4).......................注意到他的指数都是二进制的位权(不知道是不是这个名词,就像十进制的位权是 1 10 100 1000 10000,一样如1243=1*1000+2*100+3*10+4*1;而二进制的1011 是 1*2^(3)+0*2^(2)+1*2^(1)+1*2^(0) 这样是不是应该理解位权了呢?)实际上任何一个A^N都可以写成这个数列的某些项的乘积,因为N始终都可以表示成二进制,而把N表示成二进制后如果某项为1则说明需要乘上MULTI 否则不用乘上MULTI

于是就有了上面的代码,,,,哎怎么感觉我说的还是很不清楚呢?那就没办法

下面附上代码,另外一般要用快速幂的题都要取模 因为指数太大的数是会爆掉int 和long long 的

[cpp] view plaincopyprint?
  1. #include<iostream>  
  2. using namespace std;  
  3. #define mod 1000000007  
  4. long long quick_pow(int n,int base)  
  5. //n是指数 base是底 即计算的是base^n 当然结果是取模了的  
  6. {  
  7.     long long ans=1;//默认ans大于等于1因为不能算负指数  
  8.     long long multi=base;  
  9.     while(n)  
  10.     {  
  11.         if(n%2) ans*=multi;  
  12.         ans%=mod;//由于数太大一般要取模  
  13.         n/=2;  
  14.         multi*=multi;  
  15.         multi%=mod;  
  16.     }  
  17.     return ans;  
  18. }  
  19.   
  20. int main()  
  21. {  
  22.     int n,base;  
  23.     while(cin>>n>>base)  
  24.         cout<<quick_pow(n,base)<<endl;  
  25.     return 0;  
  26. }  

可能你会问了这个算法有什么用呢?其实用的更多是使用矩阵快速幂,算递推式,注意是递推式,简单的如斐波那契数列的第一亿项的结果模上10000000后是多少你还能用递推式去,逐项递推吗?当然不能,这里就可以发挥矩阵快速幂的神威了,那斐波那契数列和矩阵快速幂能有一毛钱的关系?答案是有而且很大

斐波那契的定义是f(1)=f(2)=1; 然后f(n)=f(n-1)+f(n-2) (n>=2) 我们也可以这样定义f(1)=f(2)=1; [f(n),f(n-1)]=[f(n-1),f(n-2)][1,1,1,0],其中[1,1,1,0] 是一个2*2的矩阵 上面一行是1,1,下面一行是1,0,这样就可以化简了写成[f(n),f(n-1)]=[f(2),f(1)]*[1,1,1,0]^(n-2)

化简一下


这样就可以用矩阵快速幂,快速的推出斐波那契数列的第一亿项的值了(当然是取模的值了)是不是很神奇,类似的递推式也可以,化成这种形式,用矩阵快速幂进行计算

下面附一个矩阵快速幂的代码,当然所有矩阵都是要模的


[cpp] view plaincopyprint?
  1. # include<cstdio>  
  2. # include<cstring>  
  3. using namespace std;  
  4. #define NUM 50  
  5. int MAXN,n,mod;  
  6. struct Matrix//矩阵的类  
  7. {  
  8.     int a[NUM][NUM];  
  9.     void init()           //将其初始化为单位矩阵  
  10.     {  
  11.         memset(a,0,sizeof(a));  
  12.         for(int i=0;i<MAXN;i++)  
  13.             a[i][i]=1;  
  14.     }  
  15. } A;  
  16. Matrix mul(Matrix a,Matrix b)  //(a*b)%mod  矩阵乘法  
  17. {  
  18.     Matrix ans;  
  19.     for(int i=0;i<MAXN;i++)  
  20.         for(int j=0;j<MAXN;j++)  
  21.         {  
  22.             ans.a[i][j]=0;  
  23.             for(int k=0;k<MAXN;k++)  
  24.                 ans.a[i][j]+=a.a[i][k]*b.a[k][j];  
  25.             ans.a[i][j]%=mod;  
  26.         }  
  27.     return ans;  
  28. }  
  29.   
  30. Matrix add(Matrix a,Matrix b)  //(a+b)%mod  //矩阵加法  
  31. {  
  32.     int i,j,k;  
  33.     Matrix ans;  
  34.     for(i=0;i<MAXN;i++)  
  35.         for(j=0;j<MAXN;j++)  
  36.         {  
  37.             ans.a[i][j]=a.a[i][j]+b.a[i][j];  
  38.             ans.a[i][j]%=mod;  
  39.         }  
  40.     return ans;  
  41. }  
  42.   
  43. Matrix pow(Matrix a,int n)    //(a^n)%mod  //矩阵快速幂  
  44. {  
  45.     Matrix ans;  
  46.     ans.init();  
  47.     while(n)  
  48.     {  
  49.         if(n%2)//n&1  
  50.             ans=mul(ans,a);  
  51.         n/=2;  
  52.         a=mul(a,a);  
  53.     }  
  54.     return ans;  
  55. }  
  56.   
  57. Matrix sum(Matrix a,int n)  //(a+a^2+a^3....+a^n)%mod// 矩阵的幂和  
  58. {  
  59.     int m;  
  60.     Matrix ans,pre;  
  61.     if(n==1)  
  62.         return a;  
  63.     m=n/2;  
  64.     pre=sum(a,m);                      //[1,n/2]  
  65.     ans=add(pre,mul(pre,pow(a,m)));   //ans=[1,n/2]+a^(n/2)*[1,n/2]  
  66.     if(n&1)  
  67.         ans=add(ans,pow(a,n));          //ans=ans+a^n  
  68.     return ans;  
  69. }  
  70.   
  71. void output(Matrix a)//输出矩阵  
  72. {  
  73.     for(int i=0;i<MAXN;i++)  
  74.         for(int j=0;j<MAXN;j++)  
  75.             printf("%d%c",a.a[i][j],j==MAXN-1?'\n':' ');  
  76. }  
  77. int main()  
  78. {  
  79.     freopen("in.txt","r",stdin);  
  80.     Matrix ans;  
  81.     scanf("%d%d%d",&MAXN,&n,&mod);  
  82.     for(int i=0;i<MAXN;i++)  
  83.         for(int j=0;j<MAXN;j++)  
  84.         {  
  85.             scanf("%d",&A.a[i][j]);  
  86.             A.a[i][j]%=mod;  
  87.         }  
  88.     ans=sum(A,n);  
  89.     output(ans);  
  90.     return 0;  
  91. }  

话说这是那题的代码,我不知道了

//****************************************

再给几道题,有详细的解题思路点击打开链接

//*****************************************

再介绍一题http://acm.hit.edu.cn/hoj/problem/view?id=2255

是哈工大的在线oj上的一个题目,一个类似于斐波那契的题目,题目会给出a,b,p,q,s,e, 其中f(0)=a,f(1)=b,当n>=2时 f(n)=P*f(n-1)+q*f((n-2) 求它组成的数列从第s项起一直加到第e项的和是多少,这题就不能推f(n)的矩阵乘法的递推式了,得推出他的前n项和s(n)的递推式,当然如果会推斐波那契的矩阵形式的递推形式的想必这个就不难了

推理过程如下:

F(N)=S(N)-S(N-1);

又有F(N)=P*F(N-1)+Q*F(N-2)

所以 S(N)-S(N-1)=P*(S(N-1)-S(N-2))+Q*(S(N-2)-S(N-3))

S(N)=(P+1)*S(N-1)+(Q-P)*S(N-2)-Q*S(N-3)

推到这里你应该知道怎么把它化成矩阵乘法的形式了吧?注意到右边有三个S(x)项所以矩阵递推式左边也要有三项写出来就是


化简一下


所以第s项到第e项的和就是S(E)-S(s-1) 注意是S(s-1)

OK大功告成下面就是AC代码

[cpp] view plaincopyprint?
  1. #include<cstdio>  
  2. #include<cstring>  
  3. #include<iostream>  
  4. using namespace std;  
  5. #define NUM 3  
  6. int MAXN=3,n;  
  7. long long const mod=1e7;  
  8. struct Matrix  
  9. {  
  10.     Matrix(){memset(a,0,sizeof(a));}  
  11.     long long a[NUM][NUM];  
  12.     void init()           //将其初始化为单位矩阵  
  13.     {  
  14.         memset(a,0,sizeof(a));  
  15.         for(int i=0;i<MAXN;i++)  
  16.             a[i][i]=1;  
  17.     }  
  18. } A;  
  19. Matrix mul(Matrix a,Matrix b)  //(a*b)%mod  
  20. {  
  21.     Matrix ans;  
  22.     for(int i=0;i<MAXN;i++)  
  23.         for(int j=0;j<MAXN;j++)  
  24.         {  
  25.             ans.a[i][j]=0;  
  26.             for(int k=0;k<MAXN;k++)  
  27.                 ans.a[i][j]+=a.a[i][k]*b.a[k][j];  
  28.             ans.a[i][j]%=mod;  
  29.         }  
  30.     return ans;  
  31. }  
  32. Matrix pow(Matrix a,int n)    //(a^n)%mod  
  33. {  
  34.     Matrix ans;  
  35.     ans.init();  
  36.     while(n)  
  37.     {  
  38.         if(n%2)//n&1  
  39.             ans=mul(ans,a);  
  40.         n/=2;  
  41.         a=mul(a,a);  
  42.     }  
  43.     return ans;  
  44. }  
  45. void output(Matrix a)  
  46. {  
  47.     for(int i=0;i<MAXN;i++)  
  48.     {  
  49.         for(int j=0;j<MAXN;j++)  
  50.             printf("%d ",a.a[i][j]);  
  51.         cout<<endl;  
  52.     }  
  53.   
  54. }  
  55. int main()  
  56. {  
  57.     //freopen("in.txt","r",stdin);  
  58.     long long  a,b,p,q,s,e;  
  59.     int T;cin>>T;  
  60.     while(T--)  
  61.     {  
  62.         Matrix ans;  
  63.         cin>>a>>b>>p>>q>>s>>e;  
  64.         ans.a[0][0]=p+1;  
  65.         ans.a[0][1]=q-p;  
  66.         ans.a[0][2]=-q;  
  67.         ans.a[1][0]=1;  
  68.         ans.a[2][1]=1;  
  69.         //output(ans);  
  70.         int f[3]={a,a+b,a+b+p*b+q*a};  
  71.         long long  ss=0,ee=0;  
  72.         if(s<=3&&s>0)  
  73.             ss=f[s-1];  
  74.         else if(s)  
  75.         {  
  76.             Matrix temp=pow(ans,s-3);  
  77.             for(int i=0;i<3;i++)  
  78.                 ss+=(temp.a[0][i]*f[2-i])%mod;  
  79.         }  
  80.         if(e<=2)  
  81.             ee=f[e];  
  82.         else  
  83.         {  
  84.             Matrix temp=pow(ans,e-2);  
  85.             for(int i=0;i<3;i++)  
  86.                 ee+=(temp.a[0][i]*f[2-i])%mod;  
  87.         }  
  88.         ss=(ss+mod)%mod;//取模运算要注意一下的  
  89.         ee=(ee+mod)%mod;  
  90.         cout<<(ee-ss+mod)%mod<<endl;//这里也要注意的  
  91.   
  92.     }  
  93.     return 0;  
  94. }  

0 0