两个矩阵相乘—Strassen算法与传统算法(要求矩阵阶n为2的幂)

来源:互联网 发布:淘宝网宝宝床 编辑:程序博客网 时间:2024/06/09 14:13
package com.work.home_2;

/**
 * 矩阵乘法
 * @author xuejun
 */
public class MatrixMultiplication {
    /**
     * strassen 算法(阶数n为2的幂)
     */
    public static int[][] strassens(int n, int[][] a, int[][] b) {
        int[][] c = new int[n][n];
        if (!(n == 2)) {
            // 如果n不是2,拆分矩阵 A
            int[][] a11 = new int[n / 2][n / 2];
            int[][] a12 = new int[n / 2][n / 2];
            int[][] a21 = new int[n / 2][n / 2];
            int[][] a22 = new int[n / 2][n / 2];
            // 拆分矩阵 B
            int[][] b11 = new int[n / 2][n / 2];
            int[][] b12 = new int[n / 2][n / 2];
            int[][] b21 = new int[n / 2][n / 2];
            int[][] b22 = new int[n / 2][n / 2];
            // 矩阵 A 左上角赋给矩阵 a11
            for (int i = 0; i < (n / 2); i++) {
                for (int j = 0; j < (n / 2); j++) {
                    a11[i][j] = a[i][j];
                }
            }
            // 矩阵 A 右上角赋给矩阵 a12
            for (int i = 0; i < (n / 2); i++) {
                for (int j = (n / 2), k = 0; j < n; j++, k++) {
                    a12[i][k] = a[i][j];
                }
            }
            // 矩阵 A 左下角赋给矩阵 a21
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = 0; j < (n / 2); j++) {
                    a21[k][j] = a[i][j];
                }
            }
            // 矩阵 A 右下角赋给矩阵 a22
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = (n / 2), g = 0; j < n; j++, g++) {
                    a22[k][g] = a[i][j];
                }
            }
            // 矩阵 B 左上角赋给矩阵 b11
            for (int i = 0; i < (n / 2); i++) {
                for (int j = 0; j < (n / 2); j++) {
                    b11[i][j] = b[i][j];
                }
            }
            // 矩阵 B 右上角赋给矩阵 b12
            for (int i = 0; i < (n / 2); i++) {
                for (int j = (n / 2), k = 0; j < n; j++, k++) {
                    b12[i][k] = b[i][j];
                }
            }
            // 矩阵 B 左下角赋给矩阵 b21
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = 0; j < (n / 2); j++) {
                    b21[k][j] = b[i][j];
                }
            }
            // 矩阵 B 右下角赋给矩阵 b22
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = (n / 2), g = 0; j < n; j++, g++) {
                    b22[k][g] = b[i][j];
                }
            }
            // strassen 算法的过程量
            int[][] m1 = strassens(n / 2, a11, matrixSubtraction(n / 2, b12, b22));
            int[][] m2 = strassens(n / 2, matrixAddition(n / 2, a11, a12), b22);
            int[][] m3 = strassens(n / 2, matrixAddition(n / 2, a21, a22), b11);
            int[][] m4 = strassens(n / 2, a22, matrixSubtraction(n / 2, b21, b11));
            int[][] m5 = strassens(n / 2, matrixAddition(n / 2, a11, a22), matrixAddition(n / 2, b11, b22));
            int[][] m6 = strassens(n / 2, matrixSubtraction(n / 2, a12, a22), matrixAddition(n / 2, b21, b22));
            int[][] m7 = strassens(n / 2, matrixSubtraction(n / 2, a21, a11), matrixAddition(n / 2, b11, b12));
            // 运用strassen思想将过程量简化
            int[][] c11 = matrixSubtraction(n / 2, matrixAddition(n / 2, m5, m4), matrixSubtraction(n / 2, m2, m6));
            int[][] c12 = matrixAddition(n / 2, m1, m2);
            int[][] c21 = matrixAddition(n / 2, m3, m4);
            int[][] c22 = matrixSubtraction(n / 2, matrixAddition(n / 2, m5, m1), matrixSubtraction(n / 2, m3, m7));

            // 把简化量整合给矩阵 C 赋值
            // 矩阵 C 左上角值为:
            for (int i = 0; i < (n / 2); i++) {
                for (int j = 0; j < (n / 2); j++) {
                    c[i][j] = c11[i][j];
                }
            }
            // 矩阵 C 右上角值为:
            for (int i = 0; i < (n / 2); i++) {
                for (int j = (n / 2), k = 0; j < n; j++, k++) {
                    c[i][j] = c12[i][k];
                }
            }
            // 矩阵 C 左下角值为:
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = 0; j < (n / 2); j++) {
                    c[i][j] = c21[k][j];
                }
            }
            // 矩阵 C 右下角值为:
            for (int i = (n / 2), k = 0; i < n; i++, k++) {
                for (int j = (n / 2), g = 0; j < n; j++, g++) {
                    c[i][j] = c22[k][g];
                }
            }
            return c;
        } else {
            // 若果n是2阶,则直接调用strassen 算法(2阶)
            // System.out.println("调用 strassen 算法(2阶):");
            c = MatrixMultiplication.strassen(n, a, b);
            return c;
        }
    }

    /**
     * strassen 算法(2阶)
     */
    public static int[][] strassen(int n, int[][] a, int[][] b) {
        int[][] c = new int[n][n];
        if (n != 2) {
            System.out.println("非2阶,不可以调用");
            return c;
        } else {
            // 运用strassen思想计算2阶矩阵相乘
            int m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
            int m2 = (a[1][0] + a[1][1]) * b[0][0];
            int m3 = a[0][0] * (b[0][1] - b[1][1]);
            int m4 = a[1][1] * (b[1][0] - b[0][0]);
            int m5 = (a[0][0] + a[0][1]) * b[1][1];
            int m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
            int m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);

            c[0][0] = m1 + m4 - m5 + m7;
            c[0][1] = m3 + m5;
            c[1][0] = m2 + m4;
            c[1][1] = m1 + m3 - m2 + m6;

            // 输出结果
            // System.out.println("strassen算法(2阶):俩矩阵相乘 A * B = ");
            // for (int i = 0; i < c.length; i++) {
            // for (int j = 0; j < c[i].length; j++) {
            // System.out.print(c[i][j] + ",\t ");
            // if (j == (n - 1))
            // System.out.println();
            // }
            // }
            return c;
        }
    }

    /**
     * 传统矩阵相乘,亦可用来验证strassen算法
     */
    public static int[][] traditionMu(int n, int[][] a, int[][] b) {
        int[][] c = new int[n][n];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < b.length; j++) {
                for (int k = 0; k < n; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }

        // 输出结果
        System.out.println("传统算法:俩矩阵相乘 A * B = ");
        for (int i = 0; i < c.length; i++) {
            for (int j = 0; j < c[i].length; j++) {
                System.out.print(c[i][j] + ",\t ");
                if (j == (n - 1))
                    System.out.println();
            }
        }
        return c;
    }

    /**
     * 两个矩阵相加
     */
    public static int[][] matrixAddition(int n, int[][] a, int[][] b) {
        // c 矩阵作为结果矩阵返回
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                c[i][j] = a[i][j] + b[i][j];
            }
        }
        // 输出结果
        // System.out.println("俩矩阵相加 A + B = ");
        // for (int i = 0; i < c.length; i++) {
        // for (int j = 0; j < c[i].length; j++) {
        // System.out.print(c[i][j] + ",\t ");
        // if (j == (n - 1))
        // System.out.println();
        // }
        // }
        return c;
    }

    /**
     * 两个矩阵减法
     */
    public static int[][] matrixSubtraction(int n, int[][] a, int[][] b) {
        // c 矩阵作为结果矩阵返回
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                c[i][j] = a[i][j] - b[i][j];
            }
        }
        // 输出结果
        // System.out.println("俩矩阵相减 A - B = ");
        // for (int i = 0; i < c.length; i++) {
        // for (int j = 0; j < c[i].length; j++) {
        // System.out.print(c[i][j] + ",\t ");
        // if (j == (n - 1))
        // System.out.println();
        // }
        // }
        return c;
    }
}


package com.work.home_2;

import java.util.Scanner;

/**
 * 测试函数
 *
 * @author xuejun
 */
public class Test {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        System.out.println("请输入矩阵的阶数(n = 2^x)x = :");
        int n = (int) Math.pow(2, sc.nextInt());
        sc.close();
        // 测试矩阵 A
        int[][] a = new int[n][n];
        // 给矩阵 A 赋值
        System.out.println("矩阵 A 阶数为:" + n + ",值为:");
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                a[i][j] = (int) (Math.random() * 10);
                System.out.print(a[i][j] + ",\t ");
                if (j == (n - 1))
                    System.out.println();
            }
        }
        // 给矩阵 B 赋值
        int[][] b = new int[n][n];
        // 给矩阵 B 赋值
        System.out.println("矩阵 B 阶数为:" + n + ",值为:");
        for (int i = 0; i < b.length; i++) {
            for (int j = 0; j < b[i].length; j++) {
                b[i][j] = (int) (Math.random() * 10);
                System.out.print(b[i][j] + ",\t ");
                if (j == (n - 1))
                    System.out.println();
            }
        }
        // 接收结果矩阵

        // 测试传统算法
        MatrixMultiplication.traditionMu(n, a, b);
        // 测试两矩阵加法
        // MatrixMultiplication.matrixAddition(n, a, b);
        // 测试两矩阵减法
        // MatrixMultiplication.strassen(n, a, b);
        // 测试strassen算法
        int[][] c = MatrixMultiplication.strassens(n, a, b);

        // 输出结果
        System.out.println("strassen算法(阶数为2的幂):俩矩阵相乘 A * B = ");
        for (int i = 0; i < c.length; i++) {
            for (int j = 0; j < c[i].length; j++) {
                System.out.print(c[i][j] + ",\t ");
                if (j == (n - 1))
                    System.out.println();
            }
        }
    }
}

0 0