算法导论 第四章矩阵乘法的Strassen算法

来源:互联网 发布:怎么获取网站数据库 编辑:程序博客网 时间:2024/04/30 22:41

Strassen算法的核心思想是令递归树不那么茂盛一点,只递归进行7次而不是8次n/2*n/2矩阵的乘法。减少一次矩阵乘法带来的代价。

它包含4个步骤

1.将矩阵 A B C分解成n/2*n/2的子矩阵

2.创建10个n/2*n/2矩阵s1到s10 用来保存步骤1中的子矩阵的差 和 和 。

3.创建7个n/2*n/2矩阵p1到p7,用步骤1的子矩阵和2中的10个矩阵,递归的计算7个的积。

4.对3中的pi矩阵进行加减运算 C11=p5+p4-p2+p6   C12=p1+p2  C21=p3+p4  C22=p5+p1-p3-p7

5.再把C11 C12 C21 C22 赋回给C。

// 矩阵乘法之Strassen算法.cpp : 定义控制台应用程序的入口点。//#include "stdafx.h"#include <iostream>using namespace std;const int N=4;void input(int a[][N],int n){for(int i=0;i<n;i++){for(int j=0;j<n;j++){cin>>a[i][j];}}return ;}void output(int a[][N],int n){for(int i=0;i<n;i++){for(int j=0;j<n;j++){cout<<a[i][j]<<" ";}cout<<endl;}return ;}void MATRIX_MULTIPLY(int a[][N],int b[][N],int c[][N])//当N为2时直接计算 按普通方法{for(int i=0;i<2;i++){for(int j=0;j<2;j++){c[i][j]=0;for(int k=0;k<2;k++){c[i][j]=c[i][j]+a[i][k]*b[k][j];}}}return ;}void sum(int a[][N],int b[][N],int c[][N],int n){for(int i=0;i<n;i++){for(int j=0;j<n;j++){   c[i][j]=a[i][j]+b[i][j];}}return ;}void sub(int a[][N],int b[][N],int c[][N],int n){for(int i=0;i<n;i++){for(int j=0;j<n;j++){   c[i][j]=a[i][j]-b[i][j];}}return ;}void Strassen(int n,int a[][N],int b[][N],int c[][N]){    int a11[N][N],a12[N][N],a21[N][N],a22[N][N];int b11[N][N],b12[N][N],b21[N][N],b22[N][N];int c11[N][N],c12[N][N],c21[N][N],c22[N][N];int s1[N][N],s2[N][N],s3[N][N],s4[N][N],s5[N][N],s6[N][N],s7[N][N],s8[N][N],s9[N][N],s10[N][N];int p1[N][N],p2[N][N],p3[N][N],p4[N][N],p5[N][N],p6[N][N],p7[N][N];int MM1[N][N],MM2[N][N];if(n==2)MATRIX_MULTIPLY(a,b,c);else{for(int i=0;i<n/2;i++)//第一步把a b c矩阵分解为N*N的子矩阵{for(int j=0;j<n/2;j++){a11[i][j]=a[i][j];a12[i][j]=a[i][j+n/2];a21[i][j]=a[i+n/2][j];a22[i][j]=a[i+n/2][j+n/2];b11[i][j]=b[i][j];b12[i][j]=b[i][j+n/2];b21[i][j]=b[i+n/2][j];b22[i][j]=b[i+n/2][j+n/2];}}//a b分解完成   //第二步每个矩阵保存1中创建的两个矩阵的和或差sub(b12,b22,s1,n/2);sum(a11,a12,s2,n/2);sum(a21,a22,s3,n/2);sub(b21,b11,s4,n/2);sum(a11,a22,s5,n/2);sum(b11,b22,s6,n/2);sub(a12,a22,s7,n/2);sum(b21,b22,s8,n/2);sub(a11,a21,s9,n/2);sum(b11,b12,s10,n/2);//第三步利用1中建立的子矩阵和2中建立的10个矩阵递归的计算7个矩阵的积每个矩阵pi都是N      Strassen(n/2,a11,s1,p1);Strassen(n/2,s2,b22,p2);Strassen(n/2,s3,b11,p3);Strassen(n/2,a22,s4,p4);Strassen(n/2,s5,s6,p5);Strassen(n/2,s7,s8,p6);Strassen(n/2,s9,s10,p7);//对3中创建的pi矩阵进行加减法运算,并计算出4个n/2*n/2的子矩阵sum(p5,p4,MM1,N/2);sub(p2,p6,MM2,N/2);sub(MM1,MM2,c11,N/2);//c11=p5+p4-p2+p6sum(p1,p2,c12,N/2);//c12=p1+p2sum(p3,p4,c21,N/2);//c21=p3+p4sum(p5,p1,MM1,N/2);//c21=p3+p4sum(p3,p7,MM2,N/2);sub(MM1,MM2,c22,N/2);//c11=p5+P1-P3-P7for(int i=0;i<n/2;i++){for(int j=0;j<n/2;j++){c[i][j]=c11[i][j];c[i][j+n/2]=c12[i][j];c[i+n/2][j]=c21[i][j];c[i+n/2][j+n/2]=c22[i][j];}}}return ;}int main (){int a[N][N];int b[N][N];int c[N][N];cout<<"输入矩阵A"<<endl;input(a,N);cout<<"输入矩阵B"<<endl;input(b,N);Strassen(N,a,b,c);cout<<"相乘之后的矩阵为\n";output(c,N);return 1;}


0 0
原创粉丝点击