【HDU 4920 Matrix multiplication】矩阵乘法优化+输入挂

来源:互联网 发布:赛鸽记录软件 编辑:程序博客网 时间:2024/06/05 05:11

Matrix multiplication

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)
Total Submission(s): 1121    Accepted Submission(s): 474


Problem Description
Given two matrices A and B of size n×n, find the product of them.

bobo hates big integers. So you are only asked to find the result modulo 3.
 

Input
The input consists of several tests. For each test:

The first line contains $n$ ($1 \le n \le 800$). Each of the following $n$ lines contains $n$ integers -- the description of the matrix $A$. The $j$-th integer in the $i$-th line equals $A_{ij}$. The next $n$ lines describe the matrix $B$ in a similar format ($0 \le A_{ij}, B_{ij} \le 10^9$).
 

Output
For each test:

Print $n$ lines. Each of them contains $n$ integers -- the matrix $A \times B$ in a similar format.
 

Sample Input
1
0
1
2
0 1
2 3
4 5
6 7
 

Sample Output
0
0 1
2 1
 

Author
Xiaoxu Guo (ftiasch)
 

Source
2014 Multi-University Training Contest 5

坑爹啊~~~还尼玛有卡常数的,Strassen写到哭有木有啊!!

// HDU 4920 "Matrix multiplication": for every test case read two n x n
// matrices A and B (entries 0..10^9) and print A*B with each entry reduced
// modulo 3.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <string>
#include <set>
#include <sstream>
#include <map>
#include <list>
#include <bitset>
using namespace std;

#define N 810

int n;
char a[N][N];  // a[i][k] = A[i][k] mod 3
char b[N][N];  // b[j][k] = B[k][j] mod 3 -- B is stored transposed

// Reads the next non-negative decimal integer from stdin and returns it mod 3.
// A number is congruent to its digit sum modulo 3, so only the digit sum is
// accumulated (at most 81 for values <= 10^9, so it cannot overflow).
// Any run of non-digit characters (spaces, '\n', "\r\n", ...) is skipped
// first, so the parser does not depend on numbers being separated by exactly
// one character. Returns 0 if EOF is reached before any digit.
inline char read()
{
    int c = getchar();  // int, so EOF stays distinguishable from real bytes
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    int s = 0;
    while (c >= '0' && c <= '9')
    {
        s += c - '0';
        c = getchar();
    }
    return static_cast<char>(s % 3);
}

int main()
{
    // scanf returns 1 only when n was actually read; stopping on 0 as well as
    // EOF avoids spinning forever on malformed input.
    while (scanf("%d", &n) == 1)
    {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                a[i][j] = read();
        // Store B transposed so every dot product below walks both operands
        // row-wise (contiguous memory) instead of striding down B's columns.
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                b[j][i] = read();

        for (int i = 0; i < n; i++)
        {
            const char *row = a[i];
            for (int j = 0; j < n; j++)
            {
                const char *col = b[j];
                int t = 0;  // each term is <= 4, so t <= 4 * 800: no overflow
                for (int k = 0; k < n; k++)
                    t += row[k] * col[k];
                if (j > 0)
                    putchar(' ');
                putchar('0' + t % 3);
            }
            putchar('\n');
        }
    }
    return 0;
}

另一份代码
// HDU 4920: print the product of two n x n matrices modulo 3.
// Entries are reduced mod 3 as they are read; the product uses i-k-j loop
// order so the innermost loop walks rows of b and c contiguously.
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Fast input: reads the next non-negative decimal integer from stdin into
// ret, skipping any non-digit characters before it. If EOF is reached before
// a digit, sets ret = 0 and returns false instead of looping forever.
inline bool rd(int &ret)
{
    int c;  // int, not char: getchar() returns EOF, which a char cannot hold reliably
    do
    {
        c = getchar();
        if (c == EOF)
        {
            ret = 0;
            return false;
        }
    }
    while (c < '0' || c > '9');
    ret = c - '0';
    while ((c = getchar()) >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    return true;
}

// Fast output: writes the non-negative integer a to stdout in decimal.
inline void ot(int a)
{
    if (a > 9)
        ot(a / 10);
    putchar(a % 10 + '0');
}

const int MAX_N = 807;
int n;
int a[MAX_N][MAX_N], b[MAX_N][MAX_N];
int c[MAX_N][MAX_N];

int main()
{
    while (1 == scanf("%d", &n))
    {
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j < n; ++j)
            {
                int x;
                rd(x);
                a[i][j] = x % 3;
            }
        }
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j < n; ++j)
            {
                int x;
                rd(x);
                b[i][j] = x % 3;
            }
        }
        memset(c, 0, sizeof(c));
        // Matrix-multiplication optimization: i-k-j order keeps the inner loop
        // on contiguous rows, and a zero a[i][k] contributes nothing, so skip it.
        for (int i = 0; i < n; ++i)
        {
            for (int k = 0; k < n; ++k)
            {
                const int aik = a[i][k];
                if (aik == 0)
                    continue;
                for (int j = 0; j < n; ++j)
                    c[i][j] += aik * b[k][j];  // each term <= 4, sum <= 3200
            }
        }
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j < n; ++j)
            {
                if (j != 0)
                    putchar(' ');
                ot(c[i][j] % 3);
            }
            puts("");
        }
    }
    return 0;
}


这份代码写完了虽然看着让人想哭,但还是贴这里吧,希望某某年可以用得着
// Strassen-based solution to HDU 4920 (n x n matrix product modulo 3).
// The inputs are zero-padded to the next power of two k; the recursion falls
// back to the classic triple loop once a sub-problem is 256 x 256 or smaller.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <string>
#include <set>
#include <sstream>
#include <map>
#include <list>
#include <bitset>
using namespace std;

int **A, **B, **C;  // global input/output buffers, allocated by init()
int mod = 3;

// Allocates an n x n matrix as n separately allocated rows (entries uninitialized).
static int **alloc_matrix(int n)
{
    int **m = new int *[n];
    for (int i = 0; i < n; i++)
        m[i] = new int[n];
    return m;
}

// Releases a matrix created by alloc_matrix: every row AND the row-pointer array.
inline void freeit(int **A, int n)
{
    for (int i = 0; i < n; i++)
        delete[] A[i];
    delete[] A;
}

// Allocates the global n x n buffers A, B and C.
void init(int n)
{
    A = alloc_matrix(n);
    B = alloc_matrix(n);
    C = alloc_matrix(n);
}

// Zeroes the top-left n x n part of A, B and C (this is what makes the
// padding beyond the real matrix size zero).
inline void clear(int n)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            A[i][j] = B[i][j] = C[i][j] = 0;
}

// Splits the 2n x 2n matrix A into its four n x n quadrants.
inline void Divide(int n, int **A, int **A11, int **A12, int **A21, int **A22)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
        {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + n];
            A21[i][j] = A[i + n][j];
            A22[i][j] = A[i + n][j + n];
        }
}

// Inverse of Divide: writes four n x n quadrants into the 2n x 2n matrix A.
inline void Merge(int n, int **A, int **A11, int **A12, int **A21, int **A22)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
        {
            A[i][j] = A11[i][j];
            A[i][j + n] = A12[i][j];
            A[i + n][j] = A21[i][j];
            A[i + n][j + n] = A22[i][j];
        }
}

// C = (A - B) % mod. The result can be negative (C++ % keeps the sign of the
// dividend); it stays congruent mod 3 and is normalized when printed.
inline void Sub(int n, int **A, int **B, int **C)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            C[i][j] = (A[i][j] - B[i][j]) % mod;
}

// C = (A + B) % mod.
inline void Add(int n, int **A, int **B, int **C)
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            C[i][j] = (A[i][j] + B[i][j]) % mod;
}

// Reads the next non-negative decimal integer from stdin and returns it mod 3
// (computed from the digit sum, which is congruent mod 3). Skips any run of
// non-digit separators first; returns 0 if EOF is reached before a digit.
inline int read()
{
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    int s = 0;
    while (c >= '0' && c <= '9')
    {
        s += c - '0';
        c = getchar();
    }
    return s % 3;
}

// M = A * B for n x n matrices, with n a power of two. Entries of M are
// congruent to the true product mod 3 but not necessarily reduced to 0..2.
inline void Mutiply(int n, int **A, int **B, int **M)
{
    if (n <= 256)
    {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                M[i][j] = 0;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    M[i][k] += (A[i][j] * B[j][k]) % mod;
        return;
    }

    n /= 2;
    int **A11 = alloc_matrix(n), **A12 = alloc_matrix(n);
    int **A21 = alloc_matrix(n), **A22 = alloc_matrix(n);
    int **B11 = alloc_matrix(n), **B12 = alloc_matrix(n);
    int **B21 = alloc_matrix(n), **B22 = alloc_matrix(n);
    int **M11 = alloc_matrix(n), **M12 = alloc_matrix(n);
    int **M21 = alloc_matrix(n), **M22 = alloc_matrix(n);
    int **M1 = alloc_matrix(n), **M2 = alloc_matrix(n), **M3 = alloc_matrix(n);
    int **M4 = alloc_matrix(n), **M5 = alloc_matrix(n), **M6 = alloc_matrix(n);
    int **M7 = alloc_matrix(n);
    int **T1 = alloc_matrix(n), **T2 = alloc_matrix(n);

    Divide(n, A, A11, A12, A21, A22);
    Divide(n, B, B11, B12, B21, B22);

    // The seven Strassen products.
    Sub(n, B12, B22, T1);
    Mutiply(n, A11, T1, M1);   // M1 = A11 (B12 - B22)
    Add(n, A11, A12, T2);
    Mutiply(n, T2, B22, M2);   // M2 = (A11 + A12) B22
    Add(n, A21, A22, T1);
    Mutiply(n, T1, B11, M3);   // M3 = (A21 + A22) B11
    Sub(n, B21, B11, T1);
    Mutiply(n, A22, T1, M4);   // M4 = A22 (B21 - B11)
    Add(n, A11, A22, T1);
    Add(n, B11, B22, T2);
    Mutiply(n, T1, T2, M5);    // M5 = (A11 + A22)(B11 + B22)
    Sub(n, A12, A22, T1);
    Add(n, B21, B22, T2);
    Mutiply(n, T1, T2, M6);    // M6 = (A12 - A22)(B21 + B22)
    Sub(n, A11, A21, T1);
    Add(n, B11, B12, T2);
    Mutiply(n, T1, T2, M7);    // M7 = (A11 - A21)(B11 + B12)

    // Combine into the result quadrants.
    Add(n, M5, M4, T1);
    Sub(n, T1, M2, T2);
    Add(n, T2, M6, M11);       // C11 = M5 + M4 - M2 + M6
    Add(n, M1, M2, M12);       // C12 = M1 + M2
    Add(n, M3, M4, M21);       // C21 = M3 + M4
    Add(n, M5, M1, T1);
    Sub(n, T1, M3, T2);
    Sub(n, T2, M7, M22);       // C22 = M5 + M1 - M3 - M7

    Merge(n, M, M11, M12, M21, M22);

    // Free every temporary completely (rows and row-pointer arrays).
    int **temps[] = {A11, A12, A21, A22, B11, B12, B21, B22,
                     M11, M12, M21, M22, M1, M2, M3, M4, M5, M6, M7,
                     T1, T2};
    for (size_t t = 0; t < sizeof(temps) / sizeof(temps[0]); t++)
        freeit(temps[t], n);
}

int main()
{
    int n;
    init(1024);  // 1024 is the smallest power of two >= the maximum n (800)
    while (scanf("%d", &n) == 1)
    {
        int k = 1;  // padded size: smallest power of two >= n
        while (k < n)
            k <<= 1;
        clear(k);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                A[i][j] = read();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                B[i][j] = read();

        Mutiply(k, A, B, C);

        // Print only the real n x n product, not the zero padding up to k.
        // Entries may be negative, so normalize them into 0..2.
        for (int i = 0; i < n; i++)
        {
            printf("%d", (C[i][0] % mod + mod) % mod);
            for (int j = 1; j < n; j++)
                printf(" %d", (C[i][j] % mod + mod) % mod);
            printf("\n");
        }
    }
    return 0;
}



0 0