hdu4968(矩阵快速幂)

来源:互联网 发布:淘宝刷粉丝 编辑:程序博客网 时间:2024/05/19 05:32

题意:给A、B矩阵,C = A*B,D = C^(N*N),求D中所有元素取模6的和,A是1000*6的矩阵,B是6*1000的矩阵

思路分析:如果直接算的话,矩阵快速幂的时间复杂度是O(n^3logn*n),肯定会超时,所以利用公式(A*B)^n = A*(B*A)^(n-1)*B,这样矩阵快速幂的时间复杂度为O(k^3log(n*n)).

代码如下:

#include<iostream>#include<algorithm>#include<cstring>#include<string>#include<stack>#include<queue>#include<set>#include<map>#include<stdio.h>#include<stdlib.h>#include<math.h>#define N 1005#define MOD 6#define inf 0x7ffffff#define eps 1e-9#define pi acos(-1.0)using namespace std;struct node{    int m[10][10];    node()    {        memset(m,0,sizeof(m));    }};int a[N][N],b[N][N],d[N][N];node I;void init(){    int i,j;    for(i = 0; i < 10; i++)        for(j = 0; j < 10; j++)            if(i == j) I.m[i][j] = 1;            else I.m[i][j] = 0;}node matrixmul(node a,node b,int k){    int i,j,o;    node c;    for(i = 0; i < k; i++)        for(j = 0; j < k; j++)            for(o = 0; o < k; o++)                c.m[i][j] = (c.m[i][j] + a.m[i][o]*b.m[o][j])%MOD;    return c;}node quickpow(node a,int n,int k){    node b,c;    b = a;    c = I;    while(n)    {        if(n&1)            c = matrixmul(b,c,k);        n >>= 1;        b = matrixmul(b,b,k);    }    return c;}int main(){//freopen("input.txt","r",stdin);//freopen("output.txt","w",stdout);    init();    int n,k;    while(scanf("%d%d",&n,&k))    {        if(n == 0 && k == 0) break;        int i,j,o;        for(i = 0; i < n; i++)            for(j = 0; j < k; j++){                scanf("%d",&a[i][j]);                a[i][j] %= MOD;            }        for(i = 0; i < k; i++)            for(j = 0; j < n; j++){                scanf("%d",&b[i][j]);                b[i][j] %= MOD;            }        node c;        for(i = 0; i < k; i++)            for(j = 0; j < k; j++)                for(o = 0; o < n; o++)                    c.m[i][j] = (c.m[i][j] + b[i][o]*a[o][j])%MOD;       // for(i = 0; i < k; i++){         //   for(j = 0; j < k; j++)           //     cout<<c.m[i][j]<<" ";            //cout<<endl;}        c = quickpow(c,n*n-1,k);        memset(d,0,sizeof(d));        for(i = 0; i < n; i++)            for(j = 0; j < k; j++)                for(o = 0; o < k; o++)                    d[i][j] = (d[i][j] + a[i][o]*c.m[o][j])%MOD;        memset(a,0,sizeof(a));        for(i = 0; i < n; i++)            for(j = 0; j < n; j++)                for(o = 0; o < k; o++)                    a[i][j] += (d[i][o]*b[o][j])%MOD;        int ans = 0;        for(i = 0; i < n; i++)        {            for(j = 0; j < n; j++)                ans += a[i][j]%MOD;        }        printf("%d\n",ans);    }    return 0;}


0 0
原创粉丝点击