基于Strassen算法、采用分治策略的矩阵乘法C++实现

来源:互联网 发布:linux怎么卸载jenkins 编辑:程序博客网 时间:2024/04/26 23:58

直接上代码。
注意:只支持维度为2的幂次的方阵相乘。

#include <cstdio>
#include <vector>

// Strassen divide-and-conquer multiplication of two n x n integer matrices
// read from stdin. Any n in [1, maxn] is accepted: inputs are zero-padded to
// the next power of two because the recursion halves the size down to 1x1.

// Upper bound on the dimension read from input; keeps memory use and the
// O(n^log2(7)) recursion bounded for untrusted input.
static const int maxn = 50;

// Square matrix of ints, stored row-major on the heap.
// Heap storage keeps each recursion frame small; a fixed-size array embedded
// in the struct and copied by value at every level can exhaust the stack.
struct matrix {
    int size;              // rows == columns
    std::vector<int> con;  // size * size elements, row-major, zero-initialized

    explicit matrix(int n = 0) : size(n), con(static_cast<std::size_t>(n) * n, 0) {}

    int &at(int i, int j) { return con[static_cast<std::size_t>(i) * size + j]; }
    int at(int i, int j) const { return con[static_cast<std::size_t>(i) * size + j]; }
};

// Returns the element-wise sum of the leading len x len blocks of A and B.
static matrix add(const matrix &A, const matrix &B, int len) {
    matrix res(len);
    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            res.at(i, j) = A.at(i, j) + B.at(i, j);
        }
    }
    return res;
}

// Returns the element-wise difference A - B of the leading len x len blocks.
static matrix sub(const matrix &A, const matrix &B, int len) {
    matrix res(len);
    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            res.at(i, j) = A.at(i, j) - B.at(i, j);
        }
    }
    return res;
}

// Prints the leading n x n block of a: one row per line, each value followed by a space.
static void print_it(const matrix &a, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            printf("%d ", a.at(i, j));
        }
        printf("\n");
    }
}

// Copies the len x len sub-block of input whose top-left corner is (r1, c1).
static matrix create(const matrix &input, int r1, int c1, int len) {
    matrix res(len);
    for (int i = 0; i < len; i++) {
        for (int j = 0; j < len; j++) {
            res.at(i, j) = input.at(r1 + i, c1 + j);
        }
    }
    return res;
}

// Strassen product of two len x len matrices.
// Precondition: len is a power of two (>= 1); main() pads to guarantee this.
// Writing A = [a b; c d] and B = [e f; g h] in half-size quadrants, seven
// half-size products p1..p7 replace the eight of the naive block method.
// NOTE: entries are int and overflow is not checked (signed overflow is UB);
// intermediate sums such as (a + d) can overflow before the final result does.
static matrix multi(const matrix &A, const matrix &B, int len) {
    if (len == 1) {
        matrix ender(1);
        ender.at(0, 0) = A.at(0, 0) * B.at(0, 0);
        return ender;
    }

    const int half = len / 2;
    const matrix a = create(A, 0, 0, half), b = create(A, 0, half, half);
    const matrix c = create(A, half, 0, half), d = create(A, half, half, half);
    const matrix e = create(B, 0, 0, half), f = create(B, 0, half, half);
    const matrix g = create(B, half, 0, half), h = create(B, half, half, half);

    const matrix p1 = multi(a, sub(f, h, half), half);
    const matrix p2 = multi(add(a, b, half), h, half);
    const matrix p3 = multi(add(c, d, half), e, half);
    const matrix p4 = multi(d, sub(g, e, half), half);
    const matrix p5 = multi(add(a, d, half), add(e, h, half), half);
    const matrix p6 = multi(sub(b, d, half), add(g, h, half), half);
    const matrix p7 = multi(sub(a, c, half), add(e, f, half), half);

    // Quadrants of the product [r s; t u]:
    //   r = p5 + p4 - p2 + p6,  s = p1 + p2,  t = p3 + p4,  u = p5 + p1 - p3 - p7
    const matrix r = sub(add(add(p5, p4, half), p6, half), p2, half);
    const matrix s = add(p1, p2, half);
    const matrix t = add(p3, p4, half);
    const matrix u = sub(add(p5, p1, half), add(p3, p7, half), half);

    matrix rr(len);
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            rr.at(i, j) = r.at(i, j);
            rr.at(i, j + half) = s.at(i, j);
            rr.at(i + half, j) = t.at(i, j);
            rr.at(i + half, j + half) = u.at(i, j);
        }
    }
    return rr;
}

// Smallest power of two that is >= n (for n >= 1).
static int next_pow2(int n) {
    int p = 1;
    while (p < n) {
        p *= 2;
    }
    return p;
}

// Reads n x n ints from stdin into the top-left corner of m.
// Returns false on malformed input or EOF.
static bool read_matrix(matrix &m, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            if (scanf("%d", &m.at(i, j)) != 1) {
                return false;
            }
        }
    }
    return true;
}

int main(int argc, char const *argv[])
{
    (void)argc;
    (void)argv;

    int n = 0;
    printf("输入矩阵的维数:\n");
    // Reject bad or out-of-range sizes: n <= 0 would recurse forever and
    // an unbounded n would request unbounded memory/work.
    if (scanf("%d", &n) != 1 || n < 1 || n > maxn) {
        fprintf(stderr, "维数必须是 1 到 %d 之间的整数\n", maxn);
        return 1;
    }

    // Zero-pad to a power of two: [A 0; 0 0] * [B 0; 0 0] = [AB 0; 0 0],
    // so the top-left n x n block of the padded product is exactly A * B.
    const int padded = next_pow2(n);
    matrix m1(padded), m2(padded);

    printf("第一个矩阵:\n");
    if (!read_matrix(m1, n)) {
        fprintf(stderr, "读取矩阵元素失败\n");
        return 1;
    }
    printf("第二个矩阵:\n");
    if (!read_matrix(m2, n)) {
        fprintf(stderr, "读取矩阵元素失败\n");
        return 1;
    }

    printf("计算结果:\n");
    print_it(multi(m1, m2, padded), n);
    return 0;
}
原创粉丝点击