《算法导论》学习心得(二)—— 矩阵乘法之Strassen算法
来源:互联网 发布:获取键值对java 编辑:程序博客网 时间:2024/05/01 02:34
个人blog迁移到www.forwell.me
在开始之前,请点击下载源码。提起矩阵乘法,你也许会说不就是三次循环就解决问题了吗,这有什么好说的。是啊,三个循环确实是完事了,时间效率是O(n^3),这是我们上第一节线代老师就清清楚楚的告诉我们的,但是他没有告诉你还有比这更好的矩阵乘法,时间效率为,也许你觉得这没有什么,就提高了0.2几,没啥,但是你想过没有,当N=100,10000的时候呢,Strassen算法和传统方法又有多少差别呢,让我们来看一下Strassen算法和传统方法的效率对比图:
通过图我们会发现Strassen算法在N超过50的时候就开始表现出明显的优势,然而现实生产中矩阵都是上百阶的,那Strassen算法更是占有绝对的优势,所以我们今天就很有必要学习Strassen算法,那下面就开始进入正题。
Strassen算法
1969年,德国的一位数学家Strassen证明O(N^3)的解法并不是矩阵乘法的最优算法,他做了一系列工作使得最终的时间复杂度降低到了O(n^2.80)。那他是怎么做到的呢?对于矩阵乘法 C = A × B,通常的做法是将矩阵进行分块相乘,如下图所示:
从上图可以看出这种分块相乘总共用了8次乘法,要改进算法计算时间的复杂度,必须减少乘法运算次数。按分治法的思想,Strassen提出一种新的方法,用7次乘法完成2阶矩阵的乘法,算法如下:
M1 = A11(B12 - B12)
M2 = (A11 + A12)B22
M3 = (A21 + A22)B11
M4 = A22(B21 - B11)
M5 = (A11 + A22)(B11 + B22)
M6 = (A12 - A22)(B21 + B22)
M7 = (A11 - A21)(B11 + B12)
完成了7次乘法,再做如下加法:
C11 = M5 + M4 - M2 + M6
C12 = M1 + M2
C21 = M3 + M4
C22 = M5 + M1 - M3 - M7
全部计算使用了7次乘法和18次加减法,计算时间降低到O(nE2.81)。计算复杂性得到较大改进。具体代码实现如下:
//Strassen二阶矩阵的乘法static int[][] twostrassenMatrixMultiply(int [][]x,int [][]y) //阶数为2的矩阵乘法 { int matrixXColumnLength = x[0].length;int matrixYRowLength = x.length;//获取矩阵的行长度if(matrixXColumnLength!=matrixYRowLength){throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,无法进行乘法计算!");}int p1,p2,p3,p4,p5,p6,p7;//这些都是按照算法定义进行的int [][]result = new int[2][2];p1=(y[0][1] - y[1][1]) * x[0][0]; p2=y[1][1] * (x[0][0] + x[0][1]); p3=(x[1][0] + x[1][1]) * y[0][0]; p4=x[1][1] * (y[1][0] - y[0][0]); p5=(x[0][0] + x[1][1]) * (y[0][0]+y[1][1]); p6=(x[0][1] - x[1][1]) * (y[1][0]+y[1][1]); p7=(x[0][0] - x[1][0]) * (y[0][0]+y[0][1]); result[0][0] = p5 + p4 - p2 + p6; result[0][1] = p1 + p2; result[1][0] = p3 + p4;result[1][1] = p5 + p1 - p3 - p7;return result; }整个计算过程为:
static int[][] strassenMatrixMultiply(int [][]x,int [][]y) //矩阵乘法方法 { if(x.length==2) { return twostrassenMatrixMultiply(x,y);} else { int matrixLength = x.length/2;int[][] a11,a12,a21,a22;a11 = new int[matrixLength][matrixLength];a12 = new int[matrixLength][matrixLength];a21 = new int[matrixLength][matrixLength];a22 = new int[matrixLength][matrixLength];int[][] b11,b12,b21,b22;b11 = new int[matrixLength][matrixLength];b12 = new int[matrixLength][matrixLength];b21 = new int[matrixLength][matrixLength];b22 = new int[matrixLength][matrixLength];int[][] c11,c12,c21,c22,c; c11 = new int[matrixLength][matrixLength];c12 = new int[matrixLength][matrixLength];c21 = new int[matrixLength][matrixLength];c22 = new int[matrixLength][matrixLength];c = new int[2*matrixLength][2*matrixLength];int[][] m1,m2,m3,m4,m5,m6,m7;divide(x,a11,a12,a21,a22); //拆分A、B、C矩阵 divide(y,b11,b12,b21,b22); divide(c,c11,c12,c21,c22);m1=strassenMatrixMultiply(a11,matrixMinus(b12,b22)); m2=strassenMatrixMultiply(matrixPlus(a11,a12),b22);m3=strassenMatrixMultiply(matrixPlus(a21,a22),b11);m4=strassenMatrixMultiply(a22,matrixMinus(b21,b11)); m5=strassenMatrixMultiply(matrixPlus(a11,a22),matrixPlus(b11,b22)); m6=strassenMatrixMultiply(matrixMinus(a12,a22),matrixPlus(b21,b22)); m7=strassenMatrixMultiply(matrixMinus(a11,a21),matrixPlus(b11,b12)); c11=matrixPlus(matrixMinus(matrixPlus(m5,m4),m2),m6); c12=matrixPlus(m1,m2); c21=matrixPlus(m3,m4); c22=matrixMinus(matrixMinus(matrixPlus(m5,m1),m3),m7);c=merge(c11,c12,c21,c22); //合并C矩阵 return c; } }上面就是整个算法的实现过程,欢迎大家前来讨论tangboneu@fo
完整代码:
package com.tangbo;import java.util.Random;import java.util.Scanner;/* * @Author:唐波 * Strassen矩阵乘法 * 2014.10.31 * 程序对比了传统方法和Strassen算法计算的结果是否相等 * 算法来源:1969年,德国的一位数学家Strassen证明O(N^3)的解法并不是矩阵乘法的最优算法,他做了一系列工作使得最终的时间复杂度降低到了O(n^2.80) */public class SquareMatrixMultiply {static Random random = new Random();static Scanner in;public static void main(String[] args) { int matrixLength=0;in = new Scanner(System.in); System.out.print("输入矩阵的阶数: "); matrixLength = in.nextInt();if(isEven(matrixLength)==0){int [][]x=productMatrix(matrixLength);int [][]y=productMatrix(matrixLength);System.out.println("x矩阵:");printMatrix(x);System.out.println("y矩阵:");printMatrix(y);int [][]strassenResult =strassenMatrixMultiply(x,y);//Strassen计算结果System.out.println("Strassen计算结果:");printMatrix(strassenResult);int [][] forceResult = forceMatrixMultiply(x, y);//传统方法计算结果System.out.println("传统计算结果:");printMatrix(forceResult);boolean isEqual = isEqual(forceResult, strassenResult);//比较两种计算结果是否相等if(isEqual){System.out.println("两个计算结果相等!");}else{System.err.println("两个计算结果不相等!");System.exit(0);//程序退出}}else{System.out.println("矩阵不是2^k方阵,无法计算!");}}static boolean isEqual(int [][]x,int [][]y)//遍历判断两个矩阵是否相等{boolean equal=true;for(int i =0;i<x.length;i++){for(int j=0;j<x[0].length;j++){if(x[i][j]!=y[i][j]){equal=false;}}}return equal;}static int isEven(int n)//判断输入矩阵阶数是否为2^k{ int a = 1,temp=n; while(temp%2==0) { if(temp%2==0) temp/=2; } if(temp==1) a=0; return a;} static int[][] productMatrix(int matrixLength)//自动生成矩阵{int matrix[][] = new int[matrixLength][matrixLength];//初始化矩阵for(int i=0;i<matrixLength;i++){for(int j=0;j<matrixLength;j++){matrix[i][j] = (int)(Math.random()*10);}}System.out.println();return matrix;}static void printMatrix(int matrix[][])//矩阵打印函数{int matrixRowLength = matrix.length;//获取矩阵的行数int matrixColumnLength = matrix[0].length;//获取矩阵的列数for(int i=0;i<matrixRowLength;i++){for(int j=0;j<matrixColumnLength;j++){System.out.print(matrix[i][j]+" ");}System.out.println();}}static int[][] matrixPlus(int[][] x,int[][] y) //矩阵加法 { int matrixXRowLength = x.length;//获取矩阵的行长度int matrixXColumnLength = x[0].length;int matrixYRowLength = x.length;//获取矩阵的行长度int matrixYColumnLength = x[0].length;if(matrixXColumnLength!=matrixYColumnLength || matrixXRowLength!=matrixYRowLength)//判断矩阵是否同型{throw new RuntimeException("矩阵不同型,无法进行加法计算!");}int[][] result = new int[matrixXRowLength][matrixXColumnLength];for(int i=0;i<matrixXColumnLength;i++){for (int j = 0; j < matrixXColumnLength; j++) {result[i][j] = x[i][j]+y[i][j]; }}return result;} static int[][] matrixMinus(int[][] x,int[][] y)//矩阵减法{int matrixXRowLength = x.length;//获取矩阵的行长度int matrixXColumnLength = x[0].length;int matrixYRowLength = x.length;//获取矩阵的行长度int matrixYColumnLength = x[0].length;if(matrixXColumnLength!=matrixYColumnLength || matrixXRowLength!=matrixYRowLength){throw new RuntimeException("矩阵不同型,无法进行减法计算!");}int[][] result = new int[matrixXRowLength][matrixXColumnLength];for(int i=0;i<matrixXColumnLength;i++){for (int j = 0; j < matrixXColumnLength; j++) {result[i][j] = x[i][j]-y[i][j]; }}return result;}//Strassen二阶矩阵的乘法static int[][] twostrassenMatrixMultiply(int [][]x,int [][]y) //阶数为2的矩阵乘法 { int matrixXColumnLength = x[0].length;int matrixYRowLength = x.length;//获取矩阵的行长度if(matrixXColumnLength!=matrixYRowLength){throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,无法进行乘法计算!");}int p1,p2,p3,p4,p5,p6,p7;//这些都是按照算法定义进行的int [][]result = new int[2][2];p1=(y[0][1] - y[1][1]) * x[0][0]; p2=y[1][1] * (x[0][0] + x[0][1]); p3=(x[1][0] + x[1][1]) * y[0][0]; p4=x[1][1] * (y[1][0] - y[0][0]); p5=(x[0][0] + x[1][1]) * (y[0][0]+y[1][1]); p6=(x[0][1] - x[1][1]) * (y[1][0]+y[1][1]); p7=(x[0][0] - x[1][0]) * (y[0][0]+y[0][1]); result[0][0] = p5 + p4 - p2 + p6; result[0][1] = p1 + p2; result[1][0] = p3 + p4;result[1][1] = p5 + p1 - p3 - p7;return result; } static void divide(int[][] a,int[][] a11,int[][] a12,int[][] a21,int[][] a22)//分解矩阵{ int matrixLength = a.length/2;for(int i=0;i<matrixLength;i++) for(int j=0;j<matrixLength;j++) {a11[i][j]=a[i][j];a12[i][j]=a[i][j+matrixLength]; a21[i][j]=a[i+matrixLength][j]; a22[i][j]=a[i+matrixLength][j+matrixLength]; } }static int[][] merge(int [][]a11,int [][]a12,int [][]a21,int [][]a22)//合并矩阵 { int n=a11.length;int [][] result = new int[2*n][2*n];for(int i=0;i<n;i++){for(int j=0;j<n;j++){result[i][j]=a11[i][j]; result[i][j+n]=a12[i][j]; result[i+n][j]=a21[i][j]; result[i+n][j+n]=a22[i][j]; }}return result;}static int[][] strassenMatrixMultiply(int [][]x,int [][]y) //矩阵乘法方法 { if(x.length==2) { return twostrassenMatrixMultiply(x,y);} else { int matrixLength = x.length/2;int[][] a11,a12,a21,a22;a11 = new int[matrixLength][matrixLength];a12 = new int[matrixLength][matrixLength];a21 = new int[matrixLength][matrixLength];a22 = new int[matrixLength][matrixLength];int[][] b11,b12,b21,b22;b11 = new int[matrixLength][matrixLength];b12 = new int[matrixLength][matrixLength];b21 = new int[matrixLength][matrixLength];b22 = new int[matrixLength][matrixLength];int[][] c11,c12,c21,c22,c; c11 = new int[matrixLength][matrixLength];c12 = new int[matrixLength][matrixLength];c21 = new int[matrixLength][matrixLength];c22 = new int[matrixLength][matrixLength];c = new int[2*matrixLength][2*matrixLength];int[][] m1,m2,m3,m4,m5,m6,m7;divide(x,a11,a12,a21,a22); //拆分A、B、C矩阵 divide(y,b11,b12,b21,b22); divide(c,c11,c12,c21,c22);m1=strassenMatrixMultiply(a11,matrixMinus(b12,b22)); m2=strassenMatrixMultiply(matrixPlus(a11,a12),b22);m3=strassenMatrixMultiply(matrixPlus(a21,a22),b11);m4=strassenMatrixMultiply(a22,matrixMinus(b21,b11)); m5=strassenMatrixMultiply(matrixPlus(a11,a22),matrixPlus(b11,b22)); m6=strassenMatrixMultiply(matrixMinus(a12,a22),matrixPlus(b21,b22)); m7=strassenMatrixMultiply(matrixMinus(a11,a21),matrixPlus(b11,b12)); c11=matrixPlus(matrixMinus(matrixPlus(m5,m4),m2),m6); c12=matrixPlus(m1,m2); c21=matrixPlus(m3,m4); c22=matrixMinus(matrixMinus(matrixPlus(m5,m1),m3),m7);c=merge(c11,c12,c21,c22); //合并C矩阵 return c; } }static int[][] forceMatrixMultiply(int [][]x,int [][]y){int matrixXRowLength = x.length;//获取矩阵的行长度int matrixXColumnLength = x[0].length;int matrixYRowLength = x.length;//获取矩阵的行长度int matrixYColumnLength = x[0].length;if(matrixXColumnLength!=matrixYRowLength){throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,无法进行乘法计算!");}int [][] result = new int[matrixXRowLength][matrixYColumnLength];for(int i=0;i<matrixXRowLength;i++){for(int j=0;j<matrixYColumnLength;j++){result[i][j]=0;for(int k=0;k<matrixYRowLength;k++){result[i][j] = result[i][j]+x[i][k]*y[k][j];}}}return result;}}
0 0
- 《算法导论》学习心得(二)—— 矩阵乘法之Strassen算法
- 算法导论--------------Strassen矩阵乘法
- 算法导论之四矩阵乘法的Strassen算法
- 【算法导论】矩阵乘法strassen算法
- 算法导论-矩阵乘法-strassen算法
- 《算法导论》学习笔记之Chapter4.2矩阵乘法Strassen
- 矩阵乘法 之 strassen 算法
- Strassen算法之矩阵乘法
- 算法重拾之路——strassen矩阵乘法
- 矩阵乘法——strassen算法
- 分而治之——strassen矩阵乘法算法
- strassen算法(矩阵乘法)
- STRASSEN算法(矩阵乘法)
- strassen矩阵乘法算法
- 算法导论 第四章矩阵乘法的Strassen算法
- 矩阵乘法(Strassen算法/C++实现)
- 矩阵乘法(Strassen 算法实现)
- strassen算法优化矩阵乘法
- HDU 献给杭电五十周年校庆的礼物
- 个人收藏的IT网站
- eclipse怎样自动补全变量名
- vmware network adapter vmnet8 未识别网络解决方法
- 黑马程序员——关于输入一串数字然后加密
- 《算法导论》学习心得(二)—— 矩阵乘法之Strassen算法
- 【健康产业的商机】——新世纪的最大商机就是它,你怎么看?
- 单链表反转的递归实现(Reversing a Linked List in Java, recursively)
- TCP/IP详解卷2:实现 第一章 笔记四
- 什么是网络编程,Winsock ,SDK
- SDUTOJ 2482 二叉排序树
- 汇编语言的段地址与偏移地址的一点小心得:8086CPU是为了方便存储段地址才规定其一定是10H的倍数的
- 关于并发服务器的看法
- JNI的某些数组和字符串类型转换