strassen算法(矩阵乘法)

来源:互联网 发布:同治死因 知乎 编辑:程序博客网 时间:2024/04/24 13:10
#include <iostream>
#include <vector>
using namespace std;


class Matrix
{
public:
     Matrix(int);
     Matrix(const Matrix&);
     Matrix& operator=(const Matrix&);
     int operator()(int,int) const;
     int& operator()(int,int);
     Matrix& operator +=(const Matrix&);
     Matrix& operator -=(const Matrix&);
     Matrix& operator *=(int);
     Matrix& operator *=(const Matrix&);
     Matrix GetQuarter(int) const ;//得到1/4个矩阵
     Matrix& SetQuarter(const Matrix&,int);//设置1/4个矩阵
     int Side() const;//矩阵的行/列
     void Show() const;//打印矩阵
private:
     void Malloc(int);//设置一个仿二维数组的大小
     const int MN;//矩阵的行/列
     vector< vector<int> > Data;//矩阵的数据
};


Matrix operator +(const Matrix&,const Matrix&);//全局函数:矩阵相加
Matrix operator -(const Matrix&,const Matrix&);//全局函数:矩阵相减
Matrix operator *(int,const Matrix&);//全局函数:整数乘以矩阵,在本程序中没用到
Matrix operator *(const Matrix&,int);//全局函数:矩阵乘以整数,在本程序中没用到
Matrix operator *(const Matrix&,const Matrix&);//全局函数:矩阵相乘


void Matrix::Malloc(int mn)//设定一个仿二维数组,大小为mn*mn;

    Data.resize(mn);
    for(int i=0;i<mn;++i)
       Data[i].resize(mn);
}


Matrix::Matrix(int mn):MN(mn)
{
    Malloc(MN);
}


Matrix::Matrix(const Matrix& rhs):MN(rhs.MN)//拷贝构造
{   
    Malloc(MN);
    for(int i=0;i<MN;++i)
       for(int j=0;j<MN;++j)
       Data[i][j]=rhs.Data[i][j];   
}


Matrix& Matrix::operator=(const Matrix& rhs)//矩阵赋值
{
    if(MN!=rhs.MN) throw;
    for(int i=0;i<MN;++i)
       for(int j=0;j<MN;++j)
       Data[i][j]=rhs.Data[i][j];
    return *this;
}


int Matrix::operator()(int m,int n) const//得到矩阵的某个元素,坐标从0开始
{  return Data[m][n];
}


int& Matrix::operator()(int m,int n)//得到矩阵的某个元素,坐标从0开始,可以赋值
{  return Data[m][n];
}


Matrix& Matrix::operator +=(const Matrix& rhs)//类公开函数,A+=B;
{
   if(MN!=rhs.MN) throw;
   for(int i=0;i<MN;++i)
     for(int j=0;j<MN;++j)
     Data[i][j]+=rhs.Data[i][j];
   return *this;
}


Matrix& Matrix::operator -=(const Matrix& rhs)//类公开函数,A-=B;
{
   if(MN!=rhs.MN) throw;
   for(int i=0;i<MN;++i)
     for(int j=0;j<MN;++j)
     Data[i][j]-=rhs.Data[i][j];
   return *this;
}


Matrix& Matrix::operator *=(int Num)//类公开函数,A*=i;
{
   for(int i=0;i<MN;++i)
     for(int j=0;j<MN;++j)
     Data[i][j]*=Num;
   return *this;
}


Matrix Matrix::GetQuarter(int Pos) const//得到1/4个矩阵,Pos=0代表左上,Pos=1代表右上...
{
    int X,Y;
    switch(Pos%4)
    {
    case 0:X=Y=0;break;
    case 1:X=0;Y=MN/2;break;
    case 2:X=MN/2;Y=0;break;
    case 3:X=Y=MN/2;
    }
    Matrix T(MN/2);
    for(int i=0;i<T.MN;++i)
      for(int j=0;j<T.MN;++j)
      T.Data[i][j]=Data[i+X][j+Y];
    return T;
}


Matrix& Matrix::SetQuarter(const Matrix& rhs,int Pos)//把rhs的值拷贝为自身的1/4个矩阵,Pos含义同上
{
    int X,Y;
    switch(Pos%4)
    {
    case 0:X=Y=0;break;
    case 1:X=0;Y=MN/2;break;
    case 2:X=MN/2;Y=0;break;
    case 3:X=Y=MN/2;
    }
    for(int i=0;i<rhs.MN;++i)
      for(int j=0;j<rhs.MN;++j)
      Data[i+X][j+Y]=rhs.Data[i][j];
    return *this;
}


int Matrix::Side() const//得到行/列
{
   return MN;
}


Matrix& Matrix::operator *=(const Matrix& rhs)//类公开函数 A*=B;
{
  *this= *this * rhs;//调用全局函数:矩阵相乘
  return *this;
}


void Matrix::Show() const//打印矩阵
{  cout<<"Display Matrix:"<<endl;
   for(int i=0;i<MN;++i){
     for(int j=0;j<MN;++j)
       cout<<Data[i][j]<<' ';
       cout<<endl;
   }
}


Matrix operator +(const Matrix& rhs1,const Matrix& rhs2)//全局函数:矩阵相加
{
   Matrix T(rhs1);
   return T+=rhs2;//调用类公开函数+=
}


Matrix operator -(const Matrix& rhs1,const Matrix& rhs2)//全局函数:矩阵相减
{
   Matrix T(rhs1);
   return T-=rhs2;//调用类公开函数-=
}


Matrix operator *(int Num,const Matrix& rhs)//全局函数:整数乘以矩阵
{
   Matrix T(rhs);
   return T*=Num;//调用类公开函数*=
}


Matrix operator *(const Matrix& rhs,int Num)//全局函数:矩阵乘以整数
{
   Matrix T(rhs);
   return T*=Num;//调用类公开函数*=
}


Matrix operator *(const Matrix& rhs1,const Matrix& rhs2)//全局函数,矩阵相乘
{   if(rhs1.Side()!=rhs2.Side()) throw;
    if(rhs1.Side()==2){//行/列为2,按照常规方法计算
       Matrix T(2);
       T(0,0)=rhs1(0,0)*rhs2(0,0)+rhs1(0,1)*rhs2(1,0);//A11B11+A12B21;
       T(0,1)=rhs1(0,0)*rhs2(0,1)+rhs1(0,1)*rhs2(1,1);//A11B12+A12B22
       T(1,0)=rhs1(1,0)*rhs2(0,0)+rhs1(1,1)*rhs2(1,0);//A21B11+A22B21
       T(1,1)=rhs1(1,0)*rhs2(0,1)+rhs1(1,1)*rhs2(1,1);//A21B12+A22B22
       return T;
    };
    Matrix A11(rhs1.GetQuarter(0));//第一个矩阵的左上1/4矩阵
    Matrix A12(rhs1.GetQuarter(1));//第一个矩阵的右上1/4矩阵
    Matrix A21(rhs1.GetQuarter(2));//第一个矩阵的左下1/4矩阵
    Matrix A22(rhs1.GetQuarter(3));//第一个矩阵的右下1/4矩阵
    Matrix B11(rhs2.GetQuarter(0));//第二个矩阵的左上1/4矩阵
    Matrix B12(rhs2.GetQuarter(1));//第二个矩阵的右上1/4矩阵
    Matrix B21(rhs2.GetQuarter(2));//第二个矩阵的左下1/4矩阵
    Matrix B22(rhs2.GetQuarter(3));//第二个矩阵的右下1/4矩阵
    Matrix M1(A11*(B12-B22));//递归调用全局函数,矩阵相乘
    Matrix M2((A11+A12)*B22);//递归调用全局函数,矩阵相乘
    Matrix M3((A21+A22)*B11);//递归调用全局函数,矩阵相乘
    Matrix M4(A22*(B21-B11));//递归调用全局函数,矩阵相乘
    Matrix M5((A11+A22)*(B11+B22));//递归调用全局函数,矩阵相乘
    Matrix M6((A12-A22)*(B21+B22));//递归调用全局函数,矩阵相乘
    Matrix M7((A11-A21)*(B11+B12));//递归调用全局函数,矩阵相乘
    Matrix C11(M5+M4-M2+M6);//调用全局函数,矩阵相加/减
    Matrix C12(M1+M2);//调用全局函数,矩阵相加
    Matrix C21(M3+M4);//调用全局函数,矩阵相加
    Matrix C22(M5+M1-M3-M7);//调用全局函数,矩阵相加/减


    Matrix T(rhs1.Side());//返回的矩阵
    //设置C11-C22为T的四个小矩阵
    T.SetQuarter(C11,0).SetQuarter(C12,1).SetQuarter(C21,2).SetQuarter(C22,3);
    return T;
}


bool Is2Pow(int i)//判断i是否是2的n次方
{  if(i<2) return false;
   while(i>2){
      if(i%2) return false;
      i/=2;
   }
   return i==2 ? true:false;
}
int main()
{  int M;
   cout<<"Input two matrixes[M*M],and culculate multiply with Strassen!"<<endl;
   cout<<"M=";
   cin>>M;
   if(Is2Pow(M)==false){
      cout<<"Error:M should equal 2^n";
      return 1;
   }
   cout<<"Input Matrix A["<<M<<"*"<<M<<"]:"<<endl;
   Matrix A(M);
   for(int i=0;i<M;++i)
     for(int j=0;j<M;++j)
     cin>>A(i,j);
   cout<<"Input Matrix B["<<M<<"*"<<M<<"]:"<<endl;
   Matrix B(M);
   for(int i=0;i<M;++i)
     for(int j=0;j<M;++j)
     cin>>B(i,j);
   Matrix AB(A*B);
   cout<<"A*B"<<endl;
   AB.Show();
   Matrix BA(B*A);
   cout<<"B*A"<<endl;
   BA.Show();
   return 0;
}