Strassen矩阵乘法算法-c++实现

来源:互联网 发布:网络用语 盐 编辑:程序博客网 时间:2024/03/29 17:42

基本完全参考以下文章实现,不过是看了一遍之后自己为了加深理解手写的。

http://www.mamicode.com/info-detail-673908.html

// Strassen matrix multiplication demo.
//
// Provides a small `Strassen<T>` helper class that multiplies square matrices
// stored as arrays of row pointers (`T**`) with both the schoolbook algorithm
// and Strassen's divide-and-conquer algorithm, plus a console driver that runs
// both on the same random input and prints timings and element sums.

// The original Visual Studio project included "stdafx.h" and <windows.h>
// unconditionally although nothing below uses them. They are now pulled in
// only when available so the file also builds with other toolchains.
// NOTE(review): if the project is compiled with MSVC's /Yu"stdafx.h" option,
// replace this guard with a plain `#include "stdafx.h"` as the first line.
#if defined(__has_include)
#if __has_include("stdafx.h")
#include "stdafx.h"
#endif
#endif
#ifdef _WIN32
#include <windows.h>
#endif

#include <stdio.h>

#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

/// Owns a `size` x `size` matrix and exposes it through the `T**` row-table
/// layout that the `Strassen` methods expect.
///
/// All cells live in one contiguous, value-initialised allocation that is
/// released automatically, so callers need no matching `delete[]`.
template <typename T>
class SquareBuffer {
public:
    /// @param size number of rows and columns; must be >= 0.
    explicit SquareBuffer(int size)
        : cells_(static_cast<std::size_t>(size) * static_cast<std::size_t>(size)),
          rows_(static_cast<std::size_t>(size)) {
        for (int r = 0; r < size; ++r) {
            rows_[static_cast<std::size_t>(r)] =
                cells_.data() + static_cast<std::size_t>(r) * static_cast<std::size_t>(size);
        }
    }

    // The row table points into cells_, so a member-wise copy would alias the
    // source object's storage.
    SquareBuffer(const SquareBuffer&) = delete;
    SquareBuffer& operator=(const SquareBuffer&) = delete;

    /// @return the row-pointer table; valid for the lifetime of this object.
    T** rows() { return rows_.data(); }

private:
    std::vector<T> cells_;
    std::vector<T*> rows_;
};

/// Square-matrix arithmetic on `T**` row tables.
///
/// Every method treats its matrix arguments as `size` x `size`; the caller
/// owns all storage passed in.
template <typename T>
class Strassen {
public:
    /// MatrixResult = MatrixA + MatrixB (element-wise).
    void ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int size);
    /// MatrixResult = MatrixA - MatrixB (element-wise).
    void SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int size);
    /// MatrixResult = MatrixA * MatrixB using the schoolbook triple loop.
    void NormalMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size);
    /// MatrixResult = MatrixA * MatrixB using Strassen's algorithm.
    void StrassenMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size);
    /// Assigns initial values to matrices A and B.
    void FillMatrix(T** MatrixA, T** MatrixB, int size);
    /// Sums all elements of a matrix. Used as a cheap checksum: if both
    /// algorithms yield the same sum the implementation is considered correct.
    int GetMatrixSum(T** Matrix, int size);
};

template <typename T>
void Strassen<T>::ADD(T** MatrixA, T** MatrixB, T** MatrixResult, int size) {
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
        }
    }
}

template <typename T>
void Strassen<T>::SUB(T** MatrixA, T** MatrixB, T** MatrixResult, int size) {
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
        }
    }
}

template <typename T>
void Strassen<T>::NormalMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size) {
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            // Accumulate locally and store once per output cell.
            T acc = T();
            for (int k = 0; k < size; ++k) {
                acc += MatrixA[i][k] * MatrixB[k][j];
            }
            MatrixResult[i][j] = acc;
        }
    }
}

/// Fills A and B with the same pseudo-random values in [0, 4]; both matrices
/// receive identical contents.
template <typename T>
void Strassen<T>::FillMatrix(T** MatrixA, T** MatrixB, int size) {
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            MatrixA[i][j] = MatrixB[i][j] = std::rand() % 5;
        }
    }
}

/// Multiplies two square matrices with Strassen's seven-product recursion.
///
/// @param MatrixA      left operand, `size` x `size`
/// @param MatrixB      right operand, `size` x `size`
/// @param MatrixResult output, `size` x `size`
/// @param size         matrix dimension; values <= 0 are a no-op
///
/// The recursion halves the matrix while the dimension is even and falls back
/// to NormalMul as soon as it becomes odd (which includes the 1x1 base case),
/// so the result is correct for every positive size, not only powers of two.
/// A larger cut-off (e.g. switching to NormalMul at size <= 64) would reduce
/// recursion overhead; it is left out to keep the pure recursion observable.
template <typename T>
void Strassen<T>::StrassenMul(T** MatrixA, T** MatrixB, T** MatrixResult, int size) {
    if (size <= 0) {
        return;  // nothing to multiply; also prevents endless recursion on 0
    }
    if (size % 2 != 0) {
        // An odd dimension cannot be split into four equal quadrants.
        NormalMul(MatrixA, MatrixB, MatrixResult, size);
        return;
    }

    const int half = size / 2;

    // Scratch storage is typed on T (not int) and released automatically.
    SquareBuffer<T> a11(half), a12(half), a21(half), a22(half);
    SquareBuffer<T> b11(half), b12(half), b21(half), b22(half);
    SquareBuffer<T> m1(half), m2(half), m3(half), m4(half), m5(half), m6(half), m7(half);
    SquareBuffer<T> lhs(half), rhs(half);

    T** A11 = a11.rows();
    T** A12 = a12.rows();
    T** A21 = a21.rows();
    T** A22 = a22.rows();
    T** B11 = b11.rows();
    T** B12 = b12.rows();
    T** B21 = b21.rows();
    T** B22 = b22.rows();
    T** M1 = m1.rows();
    T** M2 = m2.rows();
    T** M3 = m3.rows();
    T** M4 = m4.rows();
    T** M5 = m5.rows();
    T** M6 = m6.rows();
    T** M7 = m7.rows();
    T** Temp1 = lhs.rows();
    T** Temp2 = rhs.rows();

    // Copy the four quadrants of each operand.
    for (int i = 0; i < half; ++i) {
        for (int j = 0; j < half; ++j) {
            A11[i][j] = MatrixA[i][j];
            A12[i][j] = MatrixA[i][j + half];
            A21[i][j] = MatrixA[i + half][j];
            A22[i][j] = MatrixA[i + half][j + half];

            B11[i][j] = MatrixB[i][j];
            B12[i][j] = MatrixB[i][j + half];
            B21[i][j] = MatrixB[i + half][j];
            B22[i][j] = MatrixB[i + half][j + half];
        }
    }

    // M1 = (A11 + A22) * (B11 + B22)
    ADD(A11, A22, Temp1, half);
    ADD(B11, B22, Temp2, half);
    StrassenMul(Temp1, Temp2, M1, half);

    // M2 = (A21 + A22) * B11
    ADD(A21, A22, Temp1, half);
    StrassenMul(Temp1, B11, M2, half);

    // M3 = A11 * (B12 - B22)
    SUB(B12, B22, Temp1, half);
    StrassenMul(A11, Temp1, M3, half);

    // M4 = A22 * (B21 - B11)
    SUB(B21, B11, Temp1, half);
    StrassenMul(A22, Temp1, M4, half);

    // M5 = (A11 + A12) * B22
    ADD(A11, A12, Temp1, half);
    StrassenMul(Temp1, B22, M5, half);

    // M6 = (A21 - A11) * (B11 + B12)
    SUB(A21, A11, Temp1, half);
    ADD(B11, B12, Temp2, half);
    StrassenMul(Temp1, Temp2, M6, half);

    // M7 = (A12 - A22) * (B21 + B22)
    SUB(A12, A22, Temp1, half);
    ADD(B21, B22, Temp2, half);
    StrassenMul(Temp1, Temp2, M7, half);

    // Combine the seven products straight into the result quadrants:
    //   C11 = M1 + M4 - M5 + M7    C12 = M3 + M5
    //   C21 = M2 + M4              C22 = M1 - M2 + M3 + M6
    for (int i = 0; i < half; ++i) {
        for (int j = 0; j < half; ++j) {
            MatrixResult[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
            MatrixResult[i][j + half] = M3[i][j] + M5[i][j];
            MatrixResult[i + half][j] = M2[i][j] + M4[i][j];
            MatrixResult[i + half][j + half] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
        }
    }
}

template <typename T>
int Strassen<T>::GetMatrixSum(T** Matrix, int size) {
    int sum = 0;
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            sum += Matrix[i][j];
        }
    }
    return sum;
}

// Define STRASSEN_NO_MAIN to compile only the classes above, e.g. when this
// file is pulled into a test harness that supplies its own entry point.
#ifndef STRASSEN_NO_MAIN
/// Console driver: reads a matrix size, multiplies two random matrices with
/// both algorithms and prints clock() readings and element sums for each.
int main() {
    // std::srand(std::time(0));  // left disabled so runs are reproducible
    Strassen<int> stra;
    int N = 0;
    std::cout << "please input the size of Matrix,and the size must be the power of 2:" << std::endl;
    if (!(std::cin >> N) || N <= 0) {
        std::cerr << "invalid matrix size" << std::endl;
        return 1;
    }

    SquareBuffer<int> left(N), right(N), product(N);
    int** Matrix1 = left.rows();
    int** Matrix2 = right.rows();
    int** Matrix3 = product.rows();
    stra.FillMatrix(Matrix1, Matrix2, N);

    // std::endl is used deliberately for the start lines so they are visible
    // before a potentially long multiplication begins.
    const std::clock_t startTime_normal = std::clock();
    std::cout << "朴素算法开始时间:" << startTime_normal << std::endl;
    stra.NormalMul(Matrix1, Matrix2, Matrix3, N);
    const std::clock_t endTime_normal = std::clock();
    std::cout << "朴素算法结束时间:" << endTime_normal << '\n';
    std::cout << "总时间:" << endTime_normal - startTime_normal << '\n';
    std::cout << "sum = " << stra.GetMatrixSum(Matrix3, N) << ';' << std::endl;

    const std::clock_t startTime_strassen = std::clock();
    std::cout << "Strassen算法开始时间:" << startTime_strassen << std::endl;
    stra.StrassenMul(Matrix1, Matrix2, Matrix3, N);
    const std::clock_t endTime_strassen = std::clock();
    std::cout << "Strassen算法结束时间:" << endTime_strassen << '\n';
    std::cout << "总时间:" << endTime_strassen - startTime_strassen << '\n';
    std::cout << "sum = " << stra.GetMatrixSum(Matrix3, N) << ';' << std::endl;
    return 0;
}
#endif  // STRASSEN_NO_MAIN


0 0
原创粉丝点击