两个矩阵相乘—Strassen算法与传统算法(要求矩阵阶n为2的幂)
来源:互联网 发布:淘宝网宝宝床 编辑:程序博客网 时间:2024/06/09 14:13
/**
* 矩阵乘法
* @author xuejun
*/
public class MatrixMultiplication {
/**
* strassen 算法(阶数n为2的幂)
*/
public static int[][] strassens(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
if (!(n == 2)) {
// 如果n不是2,拆分矩阵 A
int[][] a11 = new int[n / 2][n / 2];
int[][] a12 = new int[n / 2][n / 2];
int[][] a21 = new int[n / 2][n / 2];
int[][] a22 = new int[n / 2][n / 2];
// 拆分矩阵 B
int[][] b11 = new int[n / 2][n / 2];
int[][] b12 = new int[n / 2][n / 2];
int[][] b21 = new int[n / 2][n / 2];
int[][] b22 = new int[n / 2][n / 2];
// 矩阵 A 左上角赋给矩阵 a11
for (int i = 0; i < (n / 2); i++) {
for (int j = 0; j < (n / 2); j++) {
a11[i][j] = a[i][j];
}
}
// 矩阵 A 右上角赋给矩阵 a12
for (int i = 0; i < (n / 2); i++) {
for (int j = (n / 2), k = 0; j < n; j++, k++) {
a12[i][k] = a[i][j];
}
}
// 矩阵 A 左下角赋给矩阵 a21
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = 0; j < (n / 2); j++) {
a21[k][j] = a[i][j];
}
}
// 矩阵 A 右下角赋给矩阵 a22
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = (n / 2), g = 0; j < n; j++, g++) {
a22[k][g] = a[i][j];
}
}
// 矩阵 B 左上角赋给矩阵 b11
for (int i = 0; i < (n / 2); i++) {
for (int j = 0; j < (n / 2); j++) {
b11[i][j] = b[i][j];
}
}
// 矩阵 B 右上角赋给矩阵 b12
for (int i = 0; i < (n / 2); i++) {
for (int j = (n / 2), k = 0; j < n; j++, k++) {
b12[i][k] = b[i][j];
}
}
// 矩阵 B 左下角赋给矩阵 b21
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = 0; j < (n / 2); j++) {
b21[k][j] = b[i][j];
}
}
// 矩阵 B 右下角赋给矩阵 b22
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = (n / 2), g = 0; j < n; j++, g++) {
b22[k][g] = b[i][j];
}
}
// strassen 算法的过程量
int[][] m1 = strassens(n / 2, a11, matrixSubtraction(n / 2, b12, b22));
int[][] m2 = strassens(n / 2, matrixAddition(n / 2, a11, a12), b22);
int[][] m3 = strassens(n / 2, matrixAddition(n / 2, a21, a22), b11);
int[][] m4 = strassens(n / 2, a22, matrixSubtraction(n / 2, b21, b11));
int[][] m5 = strassens(n / 2, matrixAddition(n / 2, a11, a22), matrixAddition(n / 2, b11, b22));
int[][] m6 = strassens(n / 2, matrixSubtraction(n / 2, a12, a22), matrixAddition(n / 2, b21, b22));
int[][] m7 = strassens(n / 2, matrixSubtraction(n / 2, a21, a11), matrixAddition(n / 2, b11, b12));
// 运用strassen思想将过程量简化
int[][] c11 = matrixSubtraction(n / 2, matrixAddition(n / 2, m5, m4), matrixSubtraction(n / 2, m2, m6));
int[][] c12 = matrixAddition(n / 2, m1, m2);
int[][] c21 = matrixAddition(n / 2, m3, m4);
int[][] c22 = matrixSubtraction(n / 2, matrixAddition(n / 2, m5, m1), matrixSubtraction(n / 2, m3, m7));
// 把简化量整合给矩阵 C 赋值
// 矩阵 C 左上角值为:
for (int i = 0; i < (n / 2); i++) {
for (int j = 0; j < (n / 2); j++) {
c[i][j] = c11[i][j];
}
}
// 矩阵 C 右上角值为:
for (int i = 0; i < (n / 2); i++) {
for (int j = (n / 2), k = 0; j < n; j++, k++) {
c[i][j] = c12[i][k];
}
}
// 矩阵 C 左下角值为:
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = 0; j < (n / 2); j++) {
c[i][j] = c21[k][j];
}
}
// 矩阵 C 右下角值为:
for (int i = (n / 2), k = 0; i < n; i++, k++) {
for (int j = (n / 2), g = 0; j < n; j++, g++) {
c[i][j] = c22[k][g];
}
}
return c;
} else {
// 若果n是2阶,则直接调用strassen 算法(2阶)
// System.out.println("调用 strassen 算法(2阶):");
c = MatrixMultiplication.strassen(n, a, b);
return c;
}
}
/**
* strassen 算法(2阶)
*/
public static int[][] strassen(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
if (n != 2) {
System.out.println("非2阶,不可以调用");
return c;
} else {
// 运用strassen思想计算2阶矩阵相乘
int m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
int m2 = (a[1][0] + a[1][1]) * b[0][0];
int m3 = a[0][0] * (b[0][1] - b[1][1]);
int m4 = a[1][1] * (b[1][0] - b[0][0]);
int m5 = (a[0][0] + a[0][1]) * b[1][1];
int m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
int m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
c[0][0] = m1 + m4 - m5 + m7;
c[0][1] = m3 + m5;
c[1][0] = m2 + m4;
c[1][1] = m1 + m3 - m2 + m6;
// 输出结果
// System.out.println("strassen算法(2阶):俩矩阵相乘 A * B = ");
// for (int i = 0; i < c.length; i++) {
// for (int j = 0; j < c[i].length; j++) {
// System.out.print(c[i][j] + ",\t ");
// if (j == (n - 1))
// System.out.println();
// }
// }
return c;
}
}
/**
* 传统矩阵相乘,亦可用来验证strassen算法
*/
public static int[][] traditionMu(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < b.length; j++) {
for (int k = 0; k < n; k++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
// 输出结果
System.out.println("传统算法:俩矩阵相乘 A * B = ");
for (int i = 0; i < c.length; i++) {
for (int j = 0; j < c[i].length; j++) {
System.out.print(c[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
return c;
}
/**
* 两个矩阵相加
*/
public static int[][] matrixAddition(int n, int[][] a, int[][] b) {
// c 矩阵作为结果矩阵返回
int[][] c = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c[i][j] = a[i][j] + b[i][j];
}
}
// 输出结果
// System.out.println("俩矩阵相加 A + B = ");
// for (int i = 0; i < c.length; i++) {
// for (int j = 0; j < c[i].length; j++) {
// System.out.print(c[i][j] + ",\t ");
// if (j == (n - 1))
// System.out.println();
// }
// }
return c;
}
/**
* 两个矩阵减法
*/
public static int[][] matrixSubtraction(int n, int[][] a, int[][] b) {
// c 矩阵作为结果矩阵返回
int[][] c = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c[i][j] = a[i][j] - b[i][j];
}
}
// 输出结果
// System.out.println("俩矩阵相减 A - B = ");
// for (int i = 0; i < c.length; i++) {
// for (int j = 0; j < c[i].length; j++) {
// System.out.print(c[i][j] + ",\t ");
// if (j == (n - 1))
// System.out.println();
// }
// }
return c;
}
}
package com.work.home_2;
import java.util.Scanner;
/**
* 测试函数
*
* @author xuejun
*/
public class Test {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
System.out.println("请输入矩阵的阶数(n = 2^x)x = :");
int n = (int) Math.pow(2, sc.nextInt());
sc.close();
// 测试矩阵 A
int[][] a = new int[n][n];
// 给矩阵 A 赋值
System.out.println("矩阵 A 阶数为:" + n + ",值为:");
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < a[i].length; j++) {
a[i][j] = (int) (Math.random() * 10);
System.out.print(a[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 给矩阵 B 赋值
int[][] b = new int[n][n];
// 给矩阵 B 赋值
System.out.println("矩阵 B 阶数为:" + n + ",值为:");
for (int i = 0; i < b.length; i++) {
for (int j = 0; j < b[i].length; j++) {
b[i][j] = (int) (Math.random() * 10);
System.out.print(b[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 接收结果矩阵
// 测试传统算法
MatrixMultiplication.traditionMu(n, a, b);
// 测试两矩阵加法
// MatrixMultiplication.matrixAddition(n, a, b);
// 测试两矩阵减法
// MatrixMultiplication.strassen(n, a, b);
// 测试strassen算法
int[][] c = MatrixMultiplication.strassens(n, a, b);
// 输出结果
System.out.println("strassen算法(阶数为2的幂):俩矩阵相乘 A * B = ");
for (int i = 0; i < c.length; i++) {
for (int j = 0; j < c[i].length; j++) {
System.out.print(c[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
}
}
- 两个矩阵相乘—Strassen算法与传统算法(要求矩阵阶n为2的幂)
- 两个矩阵相乘—Strassen算法(矩阵为偶数阶方阵)
- 矩阵相乘的Strassen算法
- Strassen矩阵相乘算法
- Strassen矩阵相乘算法
- 矩阵相乘算法——Strassen算法
- 矩阵分块相乘的Strassen算法
- 计算机算法:Strassen矩阵相乘算法
- 矩阵相乘的快速算法(施特拉森-Strassen算法)
- 算法导论 矩阵相乘(Strassen方法)
- 矩阵相乘Strassen算法Java实现
- 两个矩阵相乘算法
- 矩阵乘法的Strassen算法
- 矩阵乘法的Strassen算法
- Strassen矩阵算法的实现
- strassen算法(矩阵乘法)
- 两个二维矩阵相乘的算法
- [算法系列之十五]Strassen矩阵相乘算法
- 357. Count Numbers with Unique Digits
- HDOJ 4635: Strongly connected 【强连通】
- 算法学习之动态规划(leetcode 304. Range Sum Query 2D - Immutable)
- 游艇租用
- SSM搭建-Maven创建第一个web项目(22-1)
- 两个矩阵相乘—Strassen算法与传统算法(要求矩阵阶n为2的幂)
- tomcat添加https访问支持
- SSM框架项目搭建系列(五)—Spring之Bean的注解注入
- git:拉库的指定分支
- LeetCode No.45 Jump Game II
- The Suspects(并查集)
- i.MX6UL -- 架构图
- 手机设备、平板、桌面设备的相关信息
- apache缓存