矩阵幂之和(矩阵乘法)

来源:互联网 发布:华方 线切割编程 编辑:程序博客网 时间:2024/06/14 07:44

2481. [HZOI 2016][POJ3233]矩阵幂之和

时间限制:2 s   内存限制:128 MB

【题目描述】

给定一个n*n的矩阵A和一个正整数k,求S=A+A^2+A^3+...+A^k。

【输入格式】

第一行三个正整数n,k,m。

以下n行,每行n个小于m的非负整数,表示矩阵A。

【输出格式】

n行,每行n个数,表示矩阵S中的每个元素mod m的值。

【样例输入】

2 2 40 11 1

【样例输出】

1 22 3

【数据范围与约定】

对于30%的数据,k<=10^5。

对于60%的数据,m<=10^8。

对于100%的数据,n<=30,k<=10^10,m<=10^18。


刚看这题想到了之前做的一道题,求得是一个数的A+A^2+A^3+...+A^k之和,那个可以用错位相减和乘法逆元来搞,然而这个不行(难道逆矩阵?滑稽)。

首先我们把初始矩阵设为A,单位矩阵设为E,零矩阵设为O,构造出这样一个矩阵:

A E    一次幂: A^2 A+E  二次幂:A^3  A^2+A+E  三次幂: A^4  A^3+A^2+A+E

O E             O   E            O    E                 O    E

那么我们会惊奇的发现,这个矩阵的K+1次幂居然是E+A+A^2+A^3+...+A^k,那么我们再减去一个E就是最终结果。(没错,就是矩阵套矩阵.) 这个题应该还有别的更快的做法,不过个人感觉好复杂的样子。

注意:乘法爆long long,要处理一下。。。

代码:

#include<cstdio>#include<iostream>#include<cstring>#define mem(a,b) memset(a,b,sizeof(a))using namespace std;typedef long long ll;int n;ll K,ki;inline ll mul(ll x,ll y){return (x*y-(ll)(x/(long double)ki*y+1e-3)*ki+ki)%ki;}struct matrix{ll s[32][32];matrix(){mem(s,0);}void init(){mem(s,0);for(int i=0;i<n;i++) s[i][i]=1;}void cl(){mem(s,0);}friend matrix operator * (matrix x,matrix y){matrix z;ll tmp;for(int i=0;i<n;i++)for(int j=0;j<n;j++)for(int k=0;k<n;k++){tmp=mul(x.s[i][k],y.s[k][j])%ki;z.s[i][j]+=tmp;z.s[i][j]%=ki;}return z;}friend matrix operator + (matrix x,matrix y){matrix z;for(int i=0;i<n;i++)for(int j=0;j<n;j++)z.s[i][j]=(x.s[i][j]+y.s[i][j])%ki;return z;}friend matrix operator - (matrix x,matrix y){matrix z;for(int i=0;i<n;i++)for(int j=0;j<n;j++)z.s[i][j]=(x.s[i][j]-y.s[i][j]+ki)%ki;return z;}}A,E,O,fans;struct matrix2{matrix s[4][4];void cl(){for(int i=0;i<2;i++) for(int j=0;j<2;j++) s[i][j].cl();}matrix2(){cl();}void init(){for(int i=0;i<2;i++) s[i][i].init();}friend matrix2 operator * (matrix2 x,matrix2 y){matrix2 z;for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int k=0;k<2;k++){matrix tmp=x.s[i][k]*y.s[k][j];z.s[i][j]=z.s[i][j]+tmp;}return z;}}B,ans;int main(){   freopen("matrix_sum.in","r",stdin);freopen("matrix_sum.out","w",stdout);scanf("%d%lld%lld",&n,&K,&ki);ll tmp;E.init();ans.init();for(int i=0;i<n;i++)for(int j=0;j<n;j++){scanf("%lld",&tmp);A.s[i][j]=tmp;}B.s[0][0]=A;B.s[0][1]=E;B.s[1][0]=O;B.s[1][1]=E;ll edg=K+1;for(;edg;edg>>=1,B=B*B)if(edg&1) ans=ans*B;fans=ans.s[0][1]-E;for(int i=0;i<n;i++){printf("%lld",fans.s[i][0]);for(int j=1;j<n;j++)printf(" %lld",fans.s[i][j]);putchar('\n');}return 0;}


原创粉丝点击