Strassen矩阵算法分析及其C++实现 递归分治法(转)
来源:互联网 发布:h5与后端数据库 编辑:程序博客网 时间:2024/05/22 14:42
对于矩阵乘法 C = A × B,通常的做法是将矩阵进行分块相乘,如下图所示:
从上图可以看出这种分块相乘总共用了8次乘法,当然对于子矩阵相乘(如A0×B0),还可以继续递归使用分块相乘。对于中小矩阵来说,很适合使用这种分块乘法,但是对于大矩阵来说,递归的次数较多,如果能减少每次分块乘法的次数,那么性能将可以得到很好的提高。
Strassen矩阵乘法就是采用了一个简单的运算技巧,将上面的8次矩阵相乘变成了7次乘法,看别小看这减少的1次乘法,因为每递归1次,性能就提高了1/8,比如对于1024*1024的矩阵,第1次先分解成7次512*512的矩阵相乘,对于512*512的矩阵,又可以继续递归分解成256*256的矩阵相乘,…,一直递归下去,假设分解到64*64的矩阵大小后就不再递归,那么所花的时间将是分块矩阵乘法的(7/8) * (7/8) * (7/8) * (7/8) = 0.586倍,提高了快接近一倍。当然这是理论上的值,因为实际上strassen乘法增加了其他运算开销,实际性能会略低一点。
由上可见,Strassen矩阵乘法是通过递归实现的,它将一般情况下二阶矩阵乘法(可扩展到n阶,但Strassen矩阵乘法要求n是2的幂)所需的8次乘法降低为7次,其C++实现代码如下:
下面就是Strassen矩阵乘法的实现方法,
M1 = (A0 + A3) × (B0 + B3)
M2 = (A2 + A3) × B0
M3 = A0 × (B1 - B3)M4 = A3 × (B2 - B0)M5 = (A0 + A1) × B3M6 = (A2 - A0) × (B0 + B1)M7 = (A1 - A3) × (B2 + B3)C0 = M1 + M4 - M5 + M7C1 = M3 + M5C2 = M2 + M4C3 = M1 - M2 + M3 + M6
在求解M1,M2,M3,M4,M5,M6,M7时需要使用7次矩阵乘法,其他都是矩阵加法和减法。
下面看看Strassen矩阵乘法的串行实现伪代码:
Serial_StrassenMultiply(A, B, C)
{
T1 = A0 + A3;T2 = B0 + B3;StrassenMultiply(T1, T2, M1);T1 = A2 + A3;StrassenMultiply(T1, B0, M2);T1 = (B1 - B3);StrassenMultiply (A0, T1, M3);T1 = B2 - B0;StrassenMultiply(A3, T1, M4);
T1 = A0 + A1;
StrassenMultiply(T1, B3, M5);
T1 = A2 – A0;T2 = B0 + B1;StrassenMultiply(T1, T2, M6);T1 = A1 – A3;T2 = B2 + B3;StrassenMultiply(T1, T2, M7);C0 = M1 + M4 - M5 + M7C1 = M3 + M5C2 = M2 + M4C3 = M1 - M2 + M3 + M6
}
#include <iostream> using namespace std; const int N = 6; //Define the size of the Matrix template<typename T> void Strassen(int n, T A[][N], T B[][N], T C[][N]); template<typename T> void input(int n, T p[][N]); template<typename T> void output(int n, T C[][N]); int main() { //Define three Matrices int A[N][N],B[N][N],C[N][N]; //对A和B矩阵赋值,随便赋值都可以,测试用 for(int i=0; i<N; i++) { for(int j=0; j<N; j++) { A[i][j] = i * j; B[i][j] = i * j; } } //调用Strassen方法实现C=A*B Strassen(N, A, B, C); //输出矩阵C中值 output(N, C); system("pause"); return 0; } template<typename T> void input(int n, T p[][N]) { for(int i=0; i<n; i++) { cout<<"Please Input Line "<<i+1<<endl; for(int j=0; j<n; j++) { cin>>p[i][j]; } } } template<typename T> void output(int n, T C[][N]) { cout<<"The Output Matrix is :"<<endl; for(int i=0; i<n; i++) { for(int j=0; j<n; j++) { cout<<C[i][j]<<""<<endl; } } } template<typename T> void Matrix_Multiply(T A[][N], T B[][N], T C[][N]) { //Calculating A*B->C for(int i=0; i<2; i++) { for(int j=0; j<2; j++) { C[i][j] = 0; for(int t=0; t<2; t++) { C[i][j] = C[i][j] + A[i][t]*B[t][j]; } } } } template <typename T> void Matrix_Add(int n, T X[][N], T Y[][N], T Z[][N]) { for(int i=0; i<n; i++) { for(int j=0; j<n; j++) { Z[i][j] = X[i][j] + Y[i][j]; } } } template <typename T> void Matrix_Sub(int n, T X[][N], T Y[][N], T Z[][N]) { for(int i=0; i<n; i++) { for(int j=0; j<n; j++) { Z[i][j] = X[i][j] - Y[i][j]; } } } template <typename T> void Strassen(int n, T A[][N], T B[][N], T C[][N]) { T A11[N][N], A12[N][N], A21[N][N], A22[N][N]; T B11[N][N], B12[N][N], B21[N][N], B22[N][N]; T C11[N][N], C12[N][N], C21[N][N], C22[N][N]; T M1[N][N], M2[N][N], M3[N][N], M4[N][N], M5[N][N], M6[N][N], M7[N][N]; T AA[N][N], BB[N][N]; if(n == 2) { //2-order Matrix_Multiply(A, B, C); } else { //将矩阵A和B分成阶数相同的四个子矩阵,即分治思想。 for(int i=0; i<n/2; i++) { for(int j=0; j<n/2; j++) { A11[i][j] = A[i][j]; A12[i][j] = A[i][j+n/2]; A21[i][j] = A[i+n/2][j]; A22[i][j] = A[i+n/2][j+n/2]; B11[i][j] = B[i][j]; B12[i][j] = B[i][j+n/2]; B21[i][j] = B[i+n/2][j]; B22[i][j] = B[i+n/2][j+n/2]; } } //Calculate M1 = (A0 + A3) × (B0 + B3) Matrix_Add(n/2, A11, A22, AA); Matrix_Add(n/2, B11, B22, BB); Strassen(n/2, AA, BB, M1); //Calculate M2 = (A2 + A3) × B0 Matrix_Add(n/2, A21, A22, AA); Strassen(n/2, AA, B11, M2); //Calculate M3 = A0 × (B1 - B3) Matrix_Sub(n/2, B12, B22, BB); Strassen(n/2, A11, BB, M3); //Calculate M4 = A3 × (B2 - B0) Matrix_Sub(n/2, B21, B11, BB); Strassen(n/2, A22, BB, M4); //Calculate M5 = (A0 + A1) × B3 Matrix_Add(n/2, A11, A12, AA); Strassen(n/2, AA, B22, M5); //Calculate M6 = (A2 - A0) × (B0 + B1) Matrix_Sub(n/2, A21, A11, AA); Matrix_Add(n/2, B11, B12, BB); Strassen(n/2, AA, BB, M6); //Calculate M7 = (A1 - A3) × (B2 + B3) Matrix_Sub(n/2, A12, A22, AA); Matrix_Add(n/2, B21, B22, BB); Strassen(n/2, AA, BB, M7); //Calculate C0 = M1 + M4 - M5 + M7 Matrix_Add(n/2, M1, M4, AA); Matrix_Sub(n/2, M7, M5, BB); Matrix_Add(n/2, AA, BB, C11); //Calculate C1 = M3 + M5 Matrix_Add(n/2, M3, M5, C12); //Calculate C2 = M2 + M4 Matrix_Add(n/2, M2, M4, C21); //Calculate C3 = M1 - M2 + M3 + M6 Matrix_Sub(n/2, M1, M2, AA); Matrix_Add(n/2, M3, M6, BB); Matrix_Add(n/2, AA, BB, C22); //Set the result to C[][N] for(int i=0; i<n/2; i++) { for(int j=0; j<n/2; j++) { C[i][j] = C11[i][j]; C[i][j+n/2] = C12[i][j]; C[i+n/2][j] = C21[i][j]; C[i+n/2][j+n/2] = C22[i][j]; } } } }
- Strassen矩阵算法分析及其C++实现 递归分治法(转)
- Strassen矩阵算法分析及其C++实现 递归分治法(转)
- 算法导论C语言实现: 分治策略 -- 矩阵乘法的Strassen算法
- Strassen矩阵乘法(分治法)
- strassen矩阵乘法,分治实现
- 基于Strassen算法采用分治的矩阵乘法cpp实现
- 递归与分治策略:Strassen矩阵乘法
- Strassen矩阵乘法 分治与递归
- 第四章 4.2矩阵乘法的Strassen算法(分治)
- Strassen’s 矩阵乘法—分治法实现
- 矩阵乘法(Strassen算法/C++实现)
- 矩阵乘法(Strassen 算法实现)
- Strassen矩阵算法java实现
- Strassen矩阵乘法算法实现
- Strassen矩阵算法的实现
- Strassen矩阵乘法(分治法续)
- Strassen矩阵乘法(分治法续)
- 分治-Strassen矩阵乘法
- 九九乘法表
- CPM CPT CPC CPA CPS这些在营销广告的意义
- Linux(centos6.4)下ArcSDE10.0安装文档
- 在Linux下查看环境变量
- BZOJ2104【线段树】
- Strassen矩阵算法分析及其C++实现 递归分治法(转)
- 华硕升级bios的问题
- 单源最短路径的Dijkstra 算法
- map知识整理
- 网络编程面试题
- Qt5读写Access 数据库
- DbUtils源码分析系列(一)
- POJ 2406 Power Strings
- 动态数组的使用之char *res=new char(strlen(src)+1)