Strassen方法完成N*N矩阵的相乘

来源:互联网 发布:威少刷数据gif 编辑:程序博客网 时间:2024/06/06 20:42

     因为没有伪代码,所以这个程序自己写了好久,而且还纠结书上创建12个新矩阵,不用复制元素就完成分解,利用下标计算,,我最后还是没想出来,。虽然这个算法的渐进复杂度为θ(n^lg7)低于普通迭代的θ(n^3),但是需要很多其他的代价,会让运行时间增加常量倍,所以其实实用性不高,当然这个方法的发现是具有重大意义的。

   然后书上讲的是当n降为1的时候进行简单的乘法计算,我发现如果在n=2的时候就进行计算可以加快运行时间,,所以可以认识在,在算法的运用中,或者应该结合两种算法,在一个平衡点上获得最大性能。

下面是代码:虽然看着很长,有一半都在申请数组空间和删除动态数组空间,

#include<iostream>#include"MyTimer.h"#include<fstream>using namespace std;void sub(int **a,int **b,int n,int **c){for(int i=0;i!=n;++i){    for(int j=0;j!=n;++j)    c[i][j]=a[i][j]-b[i][j];}}void add(int **a,int **b,int n,int **c){  for(int i=0;i!=n;++i){    for(int j=0;j!=n;++j){    c[i][j]=a[i][j]+b[i][j];}}}void strassen_multiply(int **a,int **b,int **c,int n){ //Strassen理论上时间复杂度为θ(n^lg7)但是实际测试中由于需要很多额外的开销,速度明显比普通迭代法慢    if(n==1){  //n=1直接运算, 其实可以n=2的时候就直接运算,可以快很多,        c[0][0]=a[0][0]*b[0][0];    }    else{ //按照strassen的方法一步步算,,int mid=n/2;//书上说不用创建12个额外数组,,运用下标计算, 这个我不会,网上也没找到int **a11=new int*[mid]; int **a12=new int*[mid]; int **a21=new int*[mid]; int **a22=new int*[mid];int **b11=new int*[mid]; int **b12=new int*[mid]; int **b21=new int*[mid]; int **b22=new int*[mid];int **c11=new int*[mid]; int **c12=new int*[mid]; int **c21=new int*[mid]; int **c22=new int*[mid];int **s1=new int*[mid]; int **s2=new int*[mid]; int **s3=new int*[mid]; int **s4=new int*[mid]; int **s5=new int*[mid];int **s6=new int*[mid]; int **s7=new int*[mid]; int **s8=new int*[mid]; int **s9=new int*[mid]; int **s10=new int*[mid];int **p1=new int*[mid]; int **p2=new int*[mid]; int **p3=new int*[mid]; int **p4=new int*[mid];int **p5=new int*[mid]; int **p6=new int*[mid]; int **p7=new int*[mid];for(int i=0;i!=mid;++i){     a11[i]=new int [mid];   a12[i]=new int [mid];   a21[i]=new int [mid];   a22[i]=new int [mid];     b11[i]=new int [mid];   b12[i]=new int [mid];   b21[i]=new int [mid];   b22[i]=new int [mid];     c11[i]=new int [mid];   c12[i]=new int [mid];   c21[i]=new int [mid];   c22[i]=new int [mid];     s1[i]=new int[mid]; s2[i]=new int[mid]; s3[i]=new int[mid]; s4[i]=new int[mid]; s5[i]=new int[mid];     s6[i]=new int[mid]; s7[i]=new int[mid]; s8[i]=new int[mid]; s9[i]=new int[mid]; s10[i]=new int[mid];     p1[i]=new int [mid]; p2[i]=new int [mid]; p3[i]=new int [mid]; p4[i]=new int [mid];     p5[i]=new int [mid]; p6[i]=new int [mid]; p7[i]=new int [mid];}for(int i=0;i!=mid;++i){for(int j=0;j!=mid;++j){//分解a b成8个n/2 × n/2的小矩阵     θ(n^2)    a11[i][j]=a[i][j];    a12[i][j]=a[i][j+mid];    a21[i][j]=a[i+mid][j];    a22[i][j]=a[i+mid][j+mid];    b11[i][j]=b[i][j];    b12[i][j]=b[i][j+mid];    b21[i][j]=b[i+mid][j];    b22[i][j]=b[i+mid][j+mid];}}//计算s1 到 s10  十次n/2 × n/2的矩阵加减法 所以 θ(n^2)sub(b12,b22,mid,s1);    add(a11,a12,mid,s2);    add(a21,a22,mid,s3);sub(b21,b11,mid,s4);    add(a11,a22,mid,s5);    add(b11,b22,mid,s6);sub(a12,a22,mid,s7);    add(b21,b22,mid,s8);     sub(a11,a21,mid,s9);add(b11,b12,mid,s10);//计算p1 到p7 递归计算7次 n/2 × n/2的矩阵乘法  7T(n/2)strassen_multiply(a11,s1,p1,mid);strassen_multiply(s2,b22,p2,mid);strassen_multiply(s3,b11,p3,mid);strassen_multiply(a22,s4,p4,mid);strassen_multiply(s5,s6,p5,mid);strassen_multiply(s7,s8,p6,mid);strassen_multiply(s9,s10,p7,mid);//计算c11 c12 c21 c22 8次n/2 * n/2 矩阵的加减法 θ(n^2)add(p5,p4,mid,c11); add(c11,p6,mid,c11); sub(c11,p2,mid,c11);add(p1,p2,mid,c12);add(p3,p4,mid,c21);add(p5,p1,mid,c22);sub(c22,p3,mid,c22);sub(c22,p7,mid,c22);for(int i=0;i!=mid;++i){ //把值赋回c  其实这里也需要θ(n^2)书上没讲到         for(int j=0;j!=mid;++j){    c[i][j]=c11[i][j];    c[i][j+mid]=c12[i][j];    c[i+mid][j]=c21[i][j];    c[i+mid][j+mid]=c22[i][j];}}    for(int i=0;i!=mid;++i){        delete []a11[i];   delete []a12[i];   delete []a21[i];   delete []a22[i];        delete []b11[i];   delete []b12[i];   delete []b21[i];   delete []b22[i];        delete []c11[i];    delete[]c12[i];   delete []c21[i];   delete []c22[i];     delete []p1[i];        delete []p2[i];    delete []p3[i];  delete []p4[i];   delete []p5[i];  delete []p6[i]; delete []p7[i];    delete []s1[i];    delete []s2[i];    delete []s3[i];     delete []s4[i];    delete []s5[i];     delete []s6[i];       delete []s7[i];        delete []s8[i];        delete []s9[i];  delete []s10[i];    }    delete []a11;  delete []a12;  delete[]a21;  delete[] a22; delete []b11;  delete []b12;  delete []b21;  delete []b22; delete []c11;  delete []c12;  delete []c21;  delete[] c22; delete []s1;   delete []s2;   delete []s3;   delete []s4;  delete []s5;  delete []s6;   delete []s7;  delete []s8;  delete []s9;  delete []s10;delete[]p1; delete[]p1;delete[]p1;delete[]p1;delete[]p1;delete[]p1;delete[]p1;    }}int main(){ifstream infile("rebuf.txt");cin.rdbuf(infile.rdbuf());int n=8;int **shu1=new int*[n];int **shu2=new int*[n];int **result=new int*[n];for(int i=0;i!=n;++i){    shu1[i]=new int[n];    shu2[i]=new int[n];    result[i]=new int[n];}for(int i=0;i!=n;++i){        for(int j=0;j!=n;++j){    cin>>shu1[i][j];    }}for(int i=0;i!=n;++i){        for(int j=0;j!=n;++j){    cin>>shu2[i][j];    }}MyTimer mt;mt.Start(); strassen_multiply(shu1,shu2,result,n); mt.End(); cout<<mt.costTime<<"us"<<endl;/*for(int i=0;i!=n;++i){        for(int j=0;j!=n;++j){    cout<<result[i][j]<<" ";    }cout<<endl;}*/for(int i=0;i!=n;++i){   delete []shu1[i];   delete []shu2[i];   delete []result[i];}delete []shu1;delete []shu2;delete []result;return 0;}


0 0
原创粉丝点击