分治算法-Strassen矩阵乘法

来源:互联网 发布:钢材软件 编辑:程序博客网 时间:2024/03/29 22:26

学过线性代数的人都知道两个矩阵的乘法怎么求:对两个 N×N 的矩阵,标准算法的时间复杂度为 O(N^3)。下面直接给出标准算法。

代码:

/**
 * Conventional (schoolbook) integer matrix multiplication.
 */
public class MartixMultiply {

    /**
     * Computes the matrix product {@code a * b} with the standard triple loop.
     *
     * <p>Works for square operands of equal order as well as for rectangular
     * operands whose shapes are compatible. Arithmetic is plain {@code int}
     * arithmetic, so overflow wraps around silently.
     *
     * @param a left operand with {@code b.length} columns in every row
     * @param b right operand; all rows must have the same length
     * @return a newly allocated {@code a.length x b[0].length} matrix holding the product
     * @throws IllegalArgumentException if {@code b} is ragged or a row of {@code a}
     *         does not have exactly {@code b.length} entries
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = (inner == 0) ? 0 : b[0].length;

        for (int k = 0; k < inner; k++) {
            if (b[k].length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: row " + k + " has " + b[k].length
                                + " entries, expected " + cols);
            }
        }

        // A freshly allocated int array is already zero-filled, so no explicit
        // initialization pass is needed before accumulating.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            if (a[i].length != inner) {
                throw new IllegalArgumentException(
                        "row " + i + " of a has " + a[i].length
                                + " entries, expected " + inner);
            }
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Multiplies two sample 2x2 matrices and prints the four entries of the
     * product in row-major order on one line.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 3, 4 }, { 7, 2 } };
        int[][] c = multiply(a, b);
        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}

Strassen 提出的算法打破了 O(N^3) 的屏障。它采用分治思想,把每个矩阵划分为 4 个 N/2×N/2 的子块。

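$$A=\begin{pmatrix}A_{00}&A_{01}\\A_{10}&A_{11}\end{pmatrix},\quad B=\begin{pmatrix}B_{00}&B_{01}\\B_{10}&B_{11}\end{pmatrix},\quad C=AB=\begin{pmatrix}C_{00}&C_{01}\\C_{10}&C_{11}\end{pmatrix}$$
$$C_{00}=M_1+M_4-M_5+M_7,\quad C_{01}=M_3+M_5,\quad C_{10}=M_2+M_4,\quad C_{11}=M_1-M_2+M_3+M_6$$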

其中
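$$\begin{aligned}M_1&=(A_{00}+A_{11})(B_{00}+B_{11}),& M_2&=(A_{10}+A_{11})\,B_{00},\\ M_3&=A_{00}\,(B_{01}-B_{11}),& M_4&=A_{11}\,(B_{10}-B_{00}),\\ M_5&=(A_{00}+A_{01})\,B_{11},& M_6&=(A_{10}-A_{00})(B_{00}+B_{01}),\\ M_7&=(A_{01}-A_{11})(B_{10}+B_{11}).\end{aligned}$$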
每一层只需要 7 次规模为 N/2 的子矩阵乘法,外加常数次 O(N^2) 的矩阵加减法,因此可以得到递推关系 T(N)=7T(N/2)+O(N^2),依据主定理得到解 T(N)=O(N^{\log_2 7})≈O(N^{2.81})。
这里不给出证明;显然,这里用到了分治法的思想。

代码:

/**
 * Integer matrix multiplication using Strassen's divide-and-conquer scheme,
 * with the conventional triple-loop product as the recursion base case.
 */
public class MartixMultiply {

    /**
     * Multiplies two square matrices of the same order with Strassen's algorithm.
     *
     * <p>Each operand is split into four quadrants, seven half-size products are
     * computed recursively and then combined into the four quadrants of the
     * result. Orders of 2 or less, and odd orders (which cannot be halved
     * evenly), are delegated to {@link #multiply(int[][], int[][])}; this also
     * guarantees that the recursion terminates for every order, not only for
     * powers of two.
     *
     * @param a left operand, an {@code n x n} matrix
     * @param b right operand, an {@code n x n} matrix
     * @return a newly allocated {@code n x n} matrix holding {@code a * b}
     * @throws IllegalArgumentException if either operand is not {@code n x n}
     *         with {@code n == a.length}
     */
    public static int[][] StrassenMultiply(int[][] a, int[][] b) {
        int n = a.length;
        requireSquare(a, n, "a");
        requireSquare(b, n, "b");
        return strassen(a, b);
    }

    /** Recursive core of Strassen's algorithm; operands are known to be square and equal in order. */
    private static int[][] strassen(int[][] a, int[][] b) {
        int n = a.length;
        // Small or odd orders cannot (usefully) be split into equal quadrants.
        if (n <= 2 || n % 2 != 0) {
            return multiply(a, b);
        }

        // Quadrants of a.
        int[][] a00 = divide(a, 1);
        int[][] a01 = divide(a, 2);
        int[][] a10 = divide(a, 3);
        int[][] a11 = divide(a, 4);
        // Quadrants of b.
        int[][] b00 = divide(b, 1);
        int[][] b01 = divide(b, 2);
        int[][] b10 = divide(b, 3);
        int[][] b11 = divide(b, 4);

        // The seven Strassen products.
        int[][] m1 = strassen(addArrays(a00, a11), addArrays(b00, b11));
        int[][] m2 = strassen(addArrays(a10, a11), b00);
        int[][] m3 = strassen(a00, subArrays(b01, b11));
        int[][] m4 = strassen(a11, subArrays(b10, b00));
        int[][] m5 = strassen(addArrays(a00, a01), b11);
        int[][] m6 = strassen(subArrays(a10, a00), addArrays(b00, b01));
        int[][] m7 = strassen(subArrays(a01, a11), addArrays(b10, b11));

        int[][] c00 = addArrays(m7, subArrays(addArrays(m1, m4), m5)); // m1 + m4 - m5 + m7
        int[][] c01 = addArrays(m3, m5);                               // m3 + m5
        int[][] c10 = addArrays(m2, m4);                               // m2 + m4
        int[][] c11 = addArrays(m6, subArrays(addArrays(m1, m3), m2)); // m1 + m3 - m2 + m6

        // Stitch the four quadrants back together.
        int[][] result = new int[n][n];
        merge(result, c00, 1);
        merge(result, c01, 2);
        merge(result, c10, 3);
        merge(result, c11, 4);
        return result;
    }

    /** Throws unless {@code m} has exactly {@code n} rows of exactly {@code n} entries each. */
    private static void requireSquare(int[][] m, int n, String name) {
        if (m.length != n) {
            throw new IllegalArgumentException(
                    name + " has " + m.length + " rows, expected " + n);
        }
        for (int i = 0; i < n; i++) {
            if (m[i].length != n) {
                throw new IllegalArgumentException(
                        name + " is not square: row " + i + " has " + m[i].length
                                + " entries, expected " + n);
            }
        }
    }

    /** Rejects quadrant selectors other than 1 (top-left), 2 (top-right), 3 (bottom-left), 4 (bottom-right). */
    private static void requireQuadrant(int flag) {
        if (flag < 1 || flag > 4) {
            throw new IllegalArgumentException("quadrant flag must be 1..4, got " + flag);
        }
    }

    /**
     * Copies one quadrant of the square matrix {@code a} into a new matrix.
     *
     * @param a    square matrix of even order
     * @param flag quadrant selector: 1 top-left, 2 top-right, 3 bottom-left, 4 bottom-right
     * @return a new matrix of half the order containing the selected quadrant
     */
    private static int[][] divide(int[][] a, int flag) {
        requireQuadrant(flag);
        int half = a.length / 2;
        int rowOffset = (flag >= 3) ? half : 0;
        int colOffset = (flag % 2 == 0) ? half : 0;
        int[][] result = new int[half][half];
        for (int i = 0; i < half; i++) {
            System.arraycopy(a[rowOffset + i], colOffset, result[i], 0, half);
        }
        return result;
    }

    /** Returns a new matrix holding the element-wise sum {@code a + b} of two square matrices. */
    private static int[][] addArrays(int[][] a, int[][] b) {
        int n = a.length;
        int[][] result = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    /** Returns a new matrix holding the element-wise difference {@code a - b} of two square matrices. */
    private static int[][] subArrays(int[][] a, int[][] b) {
        int n = a.length;
        int[][] result = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = a[i][j] - b[i][j];
            }
        }
        return result;
    }

    /**
     * Copies the half-order matrix {@code b} into one quadrant of {@code a}.
     *
     * @param a    destination square matrix of even order; modified in place
     * @param b    source matrix of half the order of {@code a}
     * @param flag quadrant selector: 1 top-left, 2 top-right, 3 bottom-left, 4 bottom-right
     */
    private static void merge(int[][] a, int[][] b, int flag) {
        requireQuadrant(flag);
        int half = a.length / 2;
        int rowOffset = (flag >= 3) ? half : 0;
        int colOffset = (flag % 2 == 0) ? half : 0;
        for (int i = 0; i < half; i++) {
            System.arraycopy(b[i], 0, a[rowOffset + i], colOffset, half);
        }
    }

    /**
     * Computes the matrix product {@code a * b} with the standard triple loop.
     *
     * <p>Works for square operands of equal order as well as for rectangular
     * operands whose shapes are compatible. Arithmetic is plain {@code int}
     * arithmetic, so overflow wraps around silently.
     *
     * @param a left operand with {@code b.length} columns in every row
     * @param b right operand; all rows must have the same length
     * @return a newly allocated {@code a.length x b[0].length} matrix holding the product
     * @throws IllegalArgumentException if {@code b} is ragged or a row of {@code a}
     *         does not have exactly {@code b.length} entries
     */
    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = (inner == 0) ? 0 : b[0].length;

        for (int k = 0; k < inner; k++) {
            if (b[k].length != cols) {
                throw new IllegalArgumentException(
                        "b is ragged: row " + k + " has " + b[k].length
                                + " entries, expected " + cols);
            }
        }

        // A freshly allocated int array is already zero-filled, so no explicit
        // initialization pass is needed before accumulating.
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            if (a[i].length != inner) {
                throw new IllegalArgumentException(
                        "row " + i + " of a has " + a[i].length
                                + " entries, expected " + inner);
            }
            for (int j = 0; j < cols; j++) {
                int sum = 0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /**
     * Multiplies two sample 4x4 matrices with Strassen's algorithm and prints
     * the top-left 2x2 corner of the product in row-major order on one line.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int[][] a = { { 1, 2, 6, 7 }, { 3, 4, 5, 4 }, { 5, 8, 3, 8 },
                { -6, 4, 3, 9 } };
        int[][] b = { { 3, 4, 9, 0 }, { 7, 2, -5, -6 }, { 0, 7, -4, 6 },
                { -6, 3, -5, 4 } };
        int[][] c = StrassenMultiply(a, b);
        System.out.println(c[0][0] + " " + c[0][1] + " " + c[1][0] + " "
                + c[1][1]);
    }
}
0 0