Strassen矩阵乘法

来源:互联网 发布:扔玻璃球 算法 编辑:程序博客网 时间:2024/06/04 00:29

矩阵乘法问题

输入:n*n的矩阵A和B

输出:A和B的乘积C=AB

由于计算每个C(i,j)需要n次乘法和n-1次加法,故需n3次乘法和n3-n2次加法,因此算法复杂度为 θ(n3)。


考虑分治算法,设n=2K,k>=0。如果n>2,则矩阵可划分为n/2*n/2的子矩阵:


上述分治算法的时间复杂度满足递归方程:

T(1) = θ(1)

T(n) = 8T(n/2) + θ(n2)   n>1

解得T(n) = θ(n3)。

Strasse将其改进,令:

则时间复杂度为

T(1) = θ(1)

T(n) =  7T(n/2) + θ(n2)   n>1 得T(n) = θ(n2.81)


import java.util.*; public class Strassen{     public Strassen(){        A = new int[NUMBER][NUMBER];        B = new int[NUMBER][NUMBER];        C = new int[NUMBER][NUMBER];    }     /**     * 输入矩阵函数     * */    public void input(int a[][]){        Scanner scanner = new Scanner(System.in);        for(int i = 0; i < a.length; i++) {            for(int j = 0; j < a[i].length; j++) {                a[i][j] = scanner.nextInt();            }        }    }     /**     * 输出矩阵     * */    public void output(int[][] resault){        for(int b[] : resault) {            for(int temp : b) {                System.out.print(temp + "   ");            }            System.out.println();        }    }     /**     * 矩阵乘法,此处只是定义了2*2矩阵的乘法     * */    public void Mul(int[][] first, int[][] second, int[][] resault){        for(int i = 0; i < 2; ++i) {            for(int j = 0; j < 2; ++j) {                resault[i][j] = 0;                for(int k = 0; k < 2; ++k) {                    resault[i][j] += first[i][k] * second[k][j];                }            }        }     }     /**     * 矩阵的加法运算     * */    public void Add(int[][] first, int[][] second, int[][] resault){        for(int i = 0; i < first.length; i++) {            for(int j = 0; j < first[i].length; j++) {                resault[i][j] = first[i][j] + second[i][j];            }        }    }         /**     * 矩阵的减法运算     * */    public void sub(int[][] first, int[][] second, int[][] resault){        for(int i = 0; i < first.length; i++) {            for(int j = 0; j < first[i].length; j++) {                resault[i][j] = first[i][j] - second[i][j];            }        }    }         /**     * strassen矩阵算法     * */    public void strassen(int[][] A, int[][] B, int[][] C){        //定义一些中间变量        int [][] M1=new int [NUMBER][NUMBER];        int [][] M2=new int [NUMBER][NUMBER];        int [][] M3=new int [NUMBER][NUMBER];        int [][] M4=new int [NUMBER][NUMBER];        int [][] M5=new int [NUMBER][NUMBER];        int [][] M6=new int [NUMBER][NUMBER];        int [][] M7=new int [NUMBER][NUMBER];                 int [][] C11=new int [NUMBER][NUMBER];        int [][] C12=new int [NUMBER][NUMBER];        int [][] C21=new int [NUMBER][NUMBER];        int [][] C22=new int [NUMBER][NUMBER];                 int [][] A11=new int [NUMBER][NUMBER];        int [][] A12=new int [NUMBER][NUMBER];        int [][] A21=new int [NUMBER][NUMBER];        int [][] A22=new int [NUMBER][NUMBER];                 int [][] B11=new int [NUMBER][NUMBER];        int [][] B12=new int [NUMBER][NUMBER];        int [][] B21=new int [NUMBER][NUMBER];        int [][] B22=new int [NUMBER][NUMBER];                 int [][] temp=new int [NUMBER][NUMBER];        int [][] temp1=new int [NUMBER][NUMBER];                                   if(A.length==2){            Mul(A, B, C);        }else{            //首先将矩阵A,B 分为4块            for(int i = 0; i < A.length/2; i++) {                for(int j = 0; j < A.length/2; j++) {                     A11[i][j]=A[i][j];                     A12[i][j]=A[i][j+A.length/2];                     A21[i][j]=A[i+A.length/2][j];                     A22[i][j]=A[i+A.length/2][j+A.length/2];                     B11[i][j]=B[i][j];                     B12[i][j]=B[i][j+A.length/2];                     B21[i][j]=B[i+A.length/2][j];                     B22[i][j]=B[i+A.length/2][j+A.length/2];                }            }            //计算M1            sub(B12, B22, temp);            Mul(A11, temp, M1);            //计算M2            Add(A11, A12, temp);            Mul(temp, B22, M2);            //计算M3            Add(A21, A22, temp);            Mul(temp, B11, M3);            //M4            sub(B21, B11, temp);            Mul(A22, temp, M4);            //M5            Add(A11, A22, temp1);            Add(B11, B22, temp);            Mul(temp1, temp, M5);            //M6            sub(A12, A22, temp1);            Add(B21, B22, temp);            Mul(temp1, temp, M6);            //M7            sub(A11, A21, temp1);            Add(B11, B12, temp);            Mul(temp1, temp, M7);                         //计算C11            Add(M5, M4, temp1);            sub(temp1, M2, temp);            Add(temp, M6, C11);            //计算C12            Add(M1, M2, C12);            //C21            Add(M3, M4, C21);            //C22            Add(M5, M1, temp1);            sub(temp1, M3, temp);            sub(temp, M7, C22);                         //结果送回C中            for(int i = 0; i < C.length/2; i++) {                for(int j = 0; j < C.length/2; j++) {                    C[i][j]=C11[i][j];                    C[i][j+C.length/2]=C12[i][j];                    C[i+C.length/2][j]=C21[i][j];                    C[i+C.length/2][j+C.length/2]=C22[i][j];                }            }                                  }             }     public static void main(String[] args){        Strassen demo=new Strassen();        System.out.println("输入矩阵A");        demo.input(A);        System.out.println("输入矩阵B");        demo.input(B);        demo.strassen(A, B, C);        demo.output(C);    }     private static int A[][];    private static int B[][];    private static int C[][];    private final static int NUMBER = 4;}





0 0
原创粉丝点击