poj 3233 (矩阵乘法+二分+递归)

来源:互联网 发布:淘宝天猫店铺多少钱 编辑:程序博客网 时间:2024/05/05 02:10

题目分析:矩阵快速幂。首先我们知道 A^x 可以用矩阵快速幂求出来(具体可见poj 3070)。其次可以对k进行二分,每次将规模减半,分k为奇偶两种情况,如当k = 6和k = 7时有:

      k = 6 有: S(6) = (1 + A^3) * (A + A^2 + A^3) = (1 + A^3) * S(3)。
      k = 7 有: S(7) = A + (A + A^4) * (A + A^2 + A^3) = A + (A + A^4) * S(3)。

ps:对矩阵定义成结构体Matrix,求S时用递归,程序会比较直观,好写一点。当然定义成数组,然后再进行一些预处理,效率会更高些。
  注意:
        1.开始的时候,一直递归,层数太多,一直TLE,应该吧算出来的暂时存储的(见错误代码);
        2.注意矩阵的0次幂
        3.注意运算符重载

正确的代码:参考了http://blog.sina.com.cn/s/blog_6635898a0102e1am.html

#include<iostream>#include<cstdio>#include<algorithm>using namespace std;int n,k,m;struct node{     int matrix[50][50];};node a;//运算符重载node operator + (node x,node y)//矩阵x+矩阵y{node ans;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;return ans;}node inline mult(node x,node y)//计算矩阵x*y{node c;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){int ans=0;for(int p=1;p<=n;p++)//{ans+=(x.matrix[i][p]*y.matrix[p][j])%m;ans%=m;}c.matrix[i][j]=ans%m;}return c;}node inline func(node x,int i)//计算矩阵x^i{//printf("%d**\n",i);node temp,c;memset(temp.matrix,0,sizeof(temp.matrix));for(int j=1;j<=n;j++)          temp.matrix[j][j]=1;if(i==0)return temp;if(i==1)return x;    c=func(x,i/2);if(i%2==0)return mult(c,c);else        return mult(mult(c,c),a);}node fun(node A,int x) //计算a^1+a^2+...+a^k{   if(x==1)return A;node B=func(A,(x+1)/2);node  C=fun(A,x/2);   if(x%2==0)   return mult((func(A,0)+B),C);//return B+mult(C,B);   else   return A+mult((A+B),C);//B+mult(C,B)+C;}int main(){while(scanf("%d %d %d",&n,&k,&m)!=EOF){int i,j;        for(i=1;i<=n;i++)for(j=1;j<=n;j++)scanf("%d",&a.matrix[i][j]);        node ans=fun(a,k);for(i=1;i<=n;i++){printf("%d",ans.matrix[i][1]);for(j=2;j<=n;j++)printf(" %d",ans.matrix[i][j]);printf("\n");}}//system("pause");return 0;}



错误的代码:

  k = 6 有: S(6) = (1 + A^3) * (A + A^2 + A^3) = (1 + A^3) * S(3)。
  k = 7 有: S(7) = (1+A^3)*(A+A^2+A^3)+A^7。
睡错误的.........没检查出来为啥.............???????????

#include<iostream>#include<cstdio>#include<algorithm>using namespace std;int n,k,m;struct node{     int matrix[50][50]; bool flag;}arr[11000];node a;//运算符重载node operator + (node x,node y)//矩阵x+矩阵y{node ans;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;return ans;}node operator = (node x){/*node ans;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)ans.matrix[i][j]=x.matrix[i][j];*/return x;}node mult(node x,node y)//计算矩阵x*y{node c;for(int i=1;i<=n;i++)for(int j=1;j<=n;j++){int ans=0;for(int p=1;p<=n;p++)//{ans+=(x.matrix[i][p]*y.matrix[p][j])%m;ans%=m;}c.matrix[i][j]=ans%m;}return c;}node func(node x,int i)//计算矩阵x^i{//printf("%d**\n",i);if(i==1)return x;if(i%2==0)return mult(func(x,i/2),func(x,i/2));else        return mult(mult(func(x,i/2),func(x,i/2)),a);}node fun(int x) //计算a^1+a^2+...+a^k{node temp;if(arr[x].flag==true)return arr[x];   if(x%2==0)   temp=fun(x/2)+mult(func(a,x/2),fun(x/2));   else   temp=fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x);   arr[x]=temp;   arr[x].flag=true;   return temp;/*if(x==1)return a;   if(x%2==0)   return fun(x/2)+mult(func(a,x/2),fun(x/2));   else   return  fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x);*/}int main(){while(scanf("%d %d %d",&n,&k,&m)!=EOF){        for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)scanf("%d",&a.matrix[i][j]);for(int i=1;i<=10000;i++)arr[i].flag=false;arr[1]=a;arr[1].flag=true;        node ans=fun(k);for(int i=1;i<=n;i++){printf("%d",ans.matrix[i][1]);for(int j=2;j<=n;j++)printf(" %d",ans.matrix[i][j]);printf("\n");}}system("pause");return 0;}