[Introduction to Algorithms] Matrix Multiplication: the Strassen Algorithm
Source: Internet | Editor: 程序博客网 | Time: 2024/04/18 05:09
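Matrix operations are indispensable in scientific computing. With MATLAB they are easy, but if you write your own C or C++ code you generally need three nested loops, so the time complexity is O(n^3).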
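A minimal sketch of the conventional three-loop method, assuming square matrices stored as rows of a std::vector (the Matrix alias and the name naiveMultiply are illustrative, not from the original):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Square matrix stored row by row (an assumption for this sketch).
using Matrix = std::vector<std::vector<long long>>;

// Conventional product: three nested loops, n * n * n scalar multiplications, O(n^3).
Matrix naiveMultiply(const Matrix& A, const Matrix& B) {
    const std::size_t n = A.size();
    assert(B.size() == n); // both operands must be n x n
    Matrix C(n, std::vector<long long>(n, 0));
    for (std::size_t i = 0; i < n; ++i)          // row of A
        for (std::size_t j = 0; j < n; ++j)      // column of B
            for (std::size_t k = 0; k < n; ++k)  // dot product of row i and column j
                C[i][j] += A[i][k] * B[k][j];
    return C;
}
```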
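The figure above shows the method we normally use: multiply corresponding elements and add the products. Decomposing C = A*B into 2×2 blocks shows that 8 multiplications are needed, namely: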
r = a * e + b * g ;
s = a * f + b * h ;
t = c * e + d * g;
u = c * f + d * h;
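The figure that defined the labels is not reproduced in the text; assuming the labeling implied by the four equations above, the blocks are laid out as follows (for n×n inputs, each of a to h is an (n/2)×(n/2) submatrix):

```latex
A = \begin{pmatrix} a & b \\ c & d \end{pmatrix},\quad
B = \begin{pmatrix} e & f \\ g & h \end{pmatrix},\quad
C = AB = \begin{pmatrix} r & s \\ t & u \end{pmatrix}
       = \begin{pmatrix} ae + bg & af + bh \\ ce + dg & cf + dh \end{pmatrix}
```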
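The algorithm presented here was proposed by Strassen. It reduces the 8 multiplications to 7. That saves only one multiplication, but a multiplication costs far more time than an addition or subtraction: when the entries are submatrices, each multiplication is itself a full matrix product, while an addition is only O(n^2). The trick is to write: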
p1 = a * ( f - h )
p2 = ( a + b ) * h
p3 = ( c + d ) * e
p4 = d * ( g - e )
p5 = ( a + d ) * ( e + h )
p6 = ( b - d ) * ( g + h )
p7 = ( a - c ) * ( e + f )
Then we only need to compute p1, p2, p3, p4, p5, p6, p7 and combine them:
r = p5 + p4 + p6 - p2
s = p1 + p2
t = p3 + p4
u = p5 + p1 - p3 - p7
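As a quick sanity check (an illustrative sketch, not part of the original code), the C++ below implements p1 to p7 and the four combinations exactly as written, with scalars standing in for the blocks, and compares them with the eight-multiplication formulas for every input whose entries lie in [-2, 2]. The names strassen2x2, naive2x2 and checkAllSmallInputs are hypothetical. Scalars commute, so this only checks the algebra; the block version also holds because every product keeps the factor from A on the left and the factor from B on the right.

```cpp
#include <cassert>

// C = [[r, s], [t, u]] for A = [[a, b], [c, d]], B = [[e, f], [g, h]].
struct Block2x2 { long long r, s, t, u; };

// Seven multiplications: p1..p7 and the combinations exactly as in the text.
Block2x2 strassen2x2(long long a, long long b, long long c, long long d,
                     long long e, long long f, long long g, long long h) {
    const long long p1 = a * (f - h);
    const long long p2 = (a + b) * h;
    const long long p3 = (c + d) * e;
    const long long p4 = d * (g - e);
    const long long p5 = (a + d) * (e + h);
    const long long p6 = (b - d) * (g + h);
    const long long p7 = (a - c) * (e + f);
    return { p5 + p4 + p6 - p2,   // r
             p1 + p2,             // s
             p3 + p4,             // t
             p5 + p1 - p3 - p7 }; // u
}

// Conventional scheme with eight multiplications, for comparison.
Block2x2 naive2x2(long long a, long long b, long long c, long long d,
                  long long e, long long f, long long g, long long h) {
    return { a * e + b * g, a * f + b * h, c * e + d * g, c * f + d * h };
}

// Compare both schemes on all 5^8 = 390625 inputs with entries in [-2, 2].
void checkAllSmallInputs() {
    for (int code = 0; code < 390625; ++code) {
        long long v[8];
        int rest = code;
        for (int k = 0; k < 8; ++k) { v[k] = rest % 5 - 2; rest /= 5; } // base-5 digits -> entries
        const Block2x2 x = strassen2x2(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
        const Block2x2 y = naive2x2(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
        assert(x.r == y.r && x.s == y.s && x.t == y.t && x.u == y.u);
    }
}
```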
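In this way the eight multiplications become 7 multiplications plus 18 additions/subtractions (10 to form the factors of p1 to p7, and 8 to combine the products into r, s, t, u). Since the additions cost only O(n^2) while each multiplication is a recursive subproblem, the overall complexity drops to O( n^lg7 ) ~= O( n^2.81 ).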
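Where the exponent comes from, using the standard divide-and-conquer analysis (not spelled out in the original): each level splits n×n matrices into (n/2)×(n/2) blocks and spends Θ(n²) on additions, so with 8 versus 7 recursive products the master theorem gives (case 1, since log₂7 ≈ 2.807 > 2):

```latex
\begin{aligned}
T_{\text{naive}}(n)    &= 8\,T_{\text{naive}}(n/2) + \Theta(n^2) = \Theta\!\left(n^{\log_2 8}\right) = \Theta(n^3),\\
T_{\text{Strassen}}(n) &= 7\,T_{\text{Strassen}}(n/2) + \Theta(n^2) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta\!\left(n^{2.81}\right).
\end{aligned}
```

For n = 2^10 with recursion all the way down to 1×1 blocks, that is 8^10 = 1,073,741,824 versus 7^10 = 282,475,249 scalar multiplications.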
The C++ code is as follows:
/*
Strassen Algorithm Implementation in C++
Coded By: Seyyed Hossein Hasan Pour MatiKolaee in May 5 2010 .
Mazandaran University of Science and Technology, Babol, Mazandaran, Iran
--------------------------------------------
Email : Master.huricane@gmail.com
YM : Deathmaster_nemessis@yahoo.com
Updated may 09 2010.
*/
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <ctime>
using namespace std;
// <windows.h> and system("Pause") were dropped: nothing else depended on them,
// and without them the program builds on any platform.

int Strassen(int n, int** MatrixA, int** MatrixB, int** MatrixC); // Multiplies two matrices recursively.
int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int length); // Adds two matrices and places the result in another matrix.
int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int length); // Subtracts two matrices and places the result in another matrix.
int MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int length); // Multiplies two matrices in the conventional way.
void FillMatrix(int** matrix1, int** matrix2, int length); // Fills the matrices with random numbers.
void PrintMatrix(int** MatrixA, int MatrixSize); // Prints the matrix content.

int main()
{
    int MatrixSize = 0;
    int** MatrixA;
    int** MatrixB;
    int** MatrixC;
    clock_t startTime_For_Normal_Multipilication;
    clock_t endTime_For_Normal_Multipilication;
    clock_t startTime_For_Strassen;
    clock_t endTime_For_Strassen;

    srand(time(0));
    cout << setw(45) << "In the name of GOD";
    cout << endl << setw(60) << "Strassen Algorithm Implementation in C++ "
         << endl << endl << setw(50) << "By Seyyed Hossein Hasan Pour"
         << endl << setw(60) << "Mazandaran University of Science and Technology"
         << endl << setw(40) << "May 9 2010";
    cout << "\nPlease enter the matrix size (must be a power of two, e.g. 32, 64, 512): ";
    cin >> MatrixSize;
    // Strassen() halves N at every level, so reject sizes that are not powers of two.
    if (!cin || MatrixSize <= 0 || (MatrixSize & (MatrixSize - 1)) != 0)
    {
        cout << "Matrix size must be a positive power of two.\n";
        return 1;
    }
    int N = MatrixSize; // for readability.

    MatrixA = new int*[MatrixSize];
    MatrixB = new int*[MatrixSize];
    MatrixC = new int*[MatrixSize];
    for (int i = 0; i < MatrixSize; i++)
    {
        MatrixA[i] = new int[MatrixSize];
        MatrixB[i] = new int[MatrixSize];
        MatrixC[i] = new int[MatrixSize];
    }
    FillMatrix(MatrixA, MatrixB, MatrixSize);

    //******************* conventional multiplication test
    cout << "Phase I started: " << (startTime_For_Normal_Multipilication = clock());
    MUL(MatrixA, MatrixB, MatrixC, MatrixSize);
    cout << "\nPhase I ended: " << (endTime_For_Normal_Multipilication = clock());
    cout << "\nMatrix Result... \n";
    PrintMatrix(MatrixC, MatrixSize);

    //******************* Strassen multiplication test
    cout << "\nMultiplication started: " << (startTime_For_Strassen = clock());
    Strassen(N, MatrixA, MatrixB, MatrixC);
    cout << "\nMultiplication ended: " << (endTime_For_Strassen = clock());
    cout << "\nMatrix Result... \n";
    PrintMatrix(MatrixC, MatrixSize);

    // Cast to double before dividing: clock_t / CLOCKS_PER_SEC is integer division
    // on common platforms and would truncate to whole seconds.
    cout << "Matrix size " << MatrixSize;
    cout << "\nNormal mode " << (endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication) << " Clocks.."
         << double(endTime_For_Normal_Multipilication - startTime_For_Normal_Multipilication) / CLOCKS_PER_SEC << " Sec";
    cout << "\nStrassen mode " << (endTime_For_Strassen - startTime_For_Strassen) << " Clocks.."
         << double(endTime_For_Strassen - startTime_For_Strassen) / CLOCKS_PER_SEC << " Sec\n";

    // Free the input and output matrices: rows first, then the row-pointer arrays.
    for (int i = 0; i < MatrixSize; i++)
    {
        delete[] MatrixA[i];
        delete[] MatrixB[i];
        delete[] MatrixC[i];
    }
    delete[] MatrixA;
    delete[] MatrixB;
    delete[] MatrixC;
    return 0;
}

/*
To create a matrix of arbitrary size in C++, one way is to build it from
pointers. With a pointer-to-pointer we can make a matrix whose size is chosen
at runtime (to shrink or grow it later, allocate a new one and copy).

The idea: a pointer-to-pointer points to an array of pointers, and each of
those pointers points to one row.

    int **A;                              // declare the variable
    A = new int *[desired_array_row];     // a 1D array of row pointers

Each element of that pointer array behaves like an ordinary pointer variable,
so we can give every one of them its own row array:

    for (int i = 0; i < desired_array_row; i++)
        A[i] = new int[desired_column_size];

After this loop the 2D array is accessed with the usual notation, A[i][j],
for both reading and writing. Remember to free the memory at the end of the
program, rows first and then the pointer array:

    for (int i = 0; i < your_array_row; i++)
        delete[] A[i];
    delete[] A;

The same method extends to N-dimensional arrays; you only need the right
nesting of for loops.
*/
int Strassen(int N, int** MatrixA, int** MatrixB, int** MatrixC)
{
    int HalfSize = N / 2;
    int newSize = N / 2;

    if (N <= 64) // choosing the threshold is extremely important; try N <= 2 to see the effect
    {
        MUL(MatrixA, MatrixB, MatrixC, N);
    }
    else
    {
        int** A11; int** A12; int** A21; int** A22;
        int** B11; int** B12; int** B21; int** B22;
        int** C11; int** C12; int** C21; int** C22;
        int** M1; int** M2; int** M3; int** M4; int** M5; int** M6; int** M7;
        int** AResult; int** BResult;

        // making a 1-dimensional pointer-based array
        A11 = new int*[newSize]; A12 = new int*[newSize]; A21 = new int*[newSize]; A22 = new int*[newSize];
        B11 = new int*[newSize]; B12 = new int*[newSize]; B21 = new int*[newSize]; B22 = new int*[newSize];
        C11 = new int*[newSize]; C12 = new int*[newSize]; C21 = new int*[newSize]; C22 = new int*[newSize];
        M1 = new int*[newSize]; M2 = new int*[newSize]; M3 = new int*[newSize]; M4 = new int*[newSize];
        M5 = new int*[newSize]; M6 = new int*[newSize]; M7 = new int*[newSize];
        AResult = new int*[newSize]; BResult = new int*[newSize];

        int newLength = newSize;

        // turning that 1-dimensional pointer-based array into a 2D pointer-based array
        for (int i = 0; i < newSize; i++)
        {
            A11[i] = new int[newLength]; A12[i] = new int[newLength]; A21[i] = new int[newLength]; A22[i] = new int[newLength];
            B11[i] = new int[newLength]; B12[i] = new int[newLength]; B21[i] = new int[newLength]; B22[i] = new int[newLength];
            C11[i] = new int[newLength]; C12[i] = new int[newLength]; C21[i] = new int[newLength]; C22[i] = new int[newLength];
            M1[i] = new int[newLength]; M2[i] = new int[newLength]; M3[i] = new int[newLength]; M4[i] = new int[newLength];
            M5[i] = new int[newLength]; M6[i] = new int[newLength]; M7[i] = new int[newLength];
            AResult[i] = new int[newLength]; BResult[i] = new int[newLength];
        }

        // splitting the input matrices into 4 submatrices each
        for (int i = 0; i < N / 2; i++)
        {
            for (int j = 0; j < N / 2; j++)
            {
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + N / 2];
                A21[i][j] = MatrixA[i + N / 2][j];
                A22[i][j] = MatrixA[i + N / 2][j + N / 2];

                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + N / 2];
                B21[i][j] = MatrixB[i + N / 2][j];
                B22[i][j] = MatrixB[i + N / 2][j + N / 2];
            }
        }

        // here we calculate the M1..M7 matrices; each product is computed by Strassen itself
        // M1 = (A11 + A22)(B11 + B22)
        ADD(A11, A22, AResult, HalfSize);
        ADD(B11, B22, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M1);

        // M2 = (A21 + A22)B11
        ADD(A21, A22, AResult, HalfSize);
        Strassen(HalfSize, AResult, B11, M2);

        // M3 = A11(B12 - B22)
        SUB(B12, B22, BResult, HalfSize);
        Strassen(HalfSize, A11, BResult, M3);

        // M4 = A22(B21 - B11)
        SUB(B21, B11, BResult, HalfSize);
        Strassen(HalfSize, A22, BResult, M4);

        // M5 = (A11 + A12)B22
        ADD(A11, A12, AResult, HalfSize);
        Strassen(HalfSize, AResult, B22, M5);

        // M6 = (A21 - A11)(B11 + B12)
        SUB(A21, A11, AResult, HalfSize);
        ADD(B11, B12, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M6);

        // M7 = (A12 - A22)(B21 + B22)
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);
        Strassen(HalfSize, AResult, BResult, M7);

        // C11 = M1 + M4 - M5 + M7
        ADD(M1, M4, AResult, HalfSize);
        SUB(M7, M5, BResult, HalfSize);
        ADD(AResult, BResult, C11, HalfSize);

        // C12 = M3 + M5
        ADD(M3, M5, C12, HalfSize);

        // C21 = M2 + M4
        ADD(M2, M4, C21, HalfSize);

        // C22 = M1 + M3 - M2 + M6
        ADD(M1, M3, AResult, HalfSize);
        SUB(M6, M2, BResult, HalfSize);
        ADD(AResult, BResult, C22, HalfSize);

        // at this point we have C11..C22; put them together into the resulting matrix
        for (int i = 0; i < N / 2; i++)
        {
            for (int j = 0; j < N / 2; j++)
            {
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + N / 2] = C12[i][j];
                MatrixC[i + N / 2][j] = C21[i][j];
                MatrixC[i + N / 2][j + N / 2] = C22[i][j];
            }
        }

        // don't forget to free the space we allocated for the matrices
        for (int i = 0; i < newLength; i++)
        {
            delete[] A11[i]; delete[] A12[i]; delete[] A21[i]; delete[] A22[i];
            delete[] B11[i]; delete[] B12[i]; delete[] B21[i]; delete[] B22[i];
            delete[] C11[i]; delete[] C12[i]; delete[] C21[i]; delete[] C22[i];
            delete[] M1[i]; delete[] M2[i]; delete[] M3[i]; delete[] M4[i];
            delete[] M5[i]; delete[] M6[i]; delete[] M7[i];
            delete[] AResult[i]; delete[] BResult[i];
        }
        delete[] A11; delete[] A12; delete[] A21; delete[] A22;
        delete[] B11; delete[] B12; delete[] B21; delete[] B22;
        delete[] C11; delete[] C12; delete[] C21; delete[] C22;
        delete[] M1; delete[] M2; delete[] M3; delete[] M4;
        delete[] M5; delete[] M6; delete[] M7;
        delete[] AResult; delete[] BResult;
    } // end of else

    return 0;
}

int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
        for (int j = 0; j < MatrixSize; j++)
            MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
    return 0;
}

int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
        for (int j = 0; j < MatrixSize; j++)
            MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
    return 0;
}

int MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize)
{
    for (int i = 0; i < MatrixSize; i++)
    {
        for (int j = 0; j < MatrixSize; j++)
        {
            MatrixResult[i][j] = 0;
            for (int k = 0; k < MatrixSize; k++)
                MatrixResult[i][j] = MatrixResult[i][j] + MatrixA[i][k] * MatrixB[k][j];
        }
    }
    return 0;
}

void FillMatrix(int** MatrixA, int** MatrixB, int length)
{
    for (int row = 0; row < length; row++)
    {
        for (int column = 0; column < length; column++)
        {
            // A and B receive the same random value in [0, 4]
            MatrixB[row][column] = (MatrixA[row][column] = rand() % 5);
            // matrix2[row][column] = rand() % 2; // (translated from Persian) removing this line gives a 50% speed-up
        }
    }
}

void PrintMatrix(int** MatrixA, int MatrixSize)
{
    cout << endl;
    for (int row = 0; row < MatrixSize; row++)
    {
        for (int column = 0; column < MatrixSize; column++)
            cout << MatrixA[row][column] << "\t";
        cout << endl; // end of row
    }
    cout << endl;
}
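For comparison, here is a compact sketch of the same recursion using std::vector, so memory is released automatically instead of through long delete[] lists. It is not the author's code: the helper names (zeros, addSub, naiveMultiply, strassenPow2, strassenMultiply) and the zero-padding to the next power of two, which lifts the listing's power-of-two restriction, are illustrative assumptions. M1 to M7 and the C11 to C22 combinations are the same ones used in Strassen() above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// n x n matrix of zeros; std::vector frees itself, so no delete[] is needed.
Matrix zeros(std::size_t n) { return Matrix(n, std::vector<long long>(n, 0)); }

// Element-wise X + sign * Y (sign is +1 or -1), playing the role of ADD/SUB.
Matrix addSub(const Matrix& X, const Matrix& Y, long long sign) {
    const std::size_t n = X.size();
    Matrix R = zeros(n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            R[i][j] = X[i][j] + sign * Y[i][j];
    return R;
}

// Conventional O(n^3) product, used as the base case (like MUL).
Matrix naiveMultiply(const Matrix& A, const Matrix& B) {
    const std::size_t n = A.size();
    Matrix C = zeros(n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t j = 0; j < n; ++j)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

// Recursive Strassen for n a power of two; at or below `threshold` it falls back to naiveMultiply.
Matrix strassenPow2(const Matrix& A, const Matrix& B, std::size_t threshold) {
    const std::size_t n = A.size();
    if (n <= threshold) return naiveMultiply(A, B);
    const std::size_t m = n / 2;
    // Copy quadrant q of M: 0 -> 11, 1 -> 12, 2 -> 21, 3 -> 22.
    auto quad = [m](const Matrix& M, int q) {
        Matrix Q = zeros(m);
        const std::size_t r0 = (q / 2) * m, c0 = (q % 2) * m;
        for (std::size_t i = 0; i < m; ++i)
            for (std::size_t j = 0; j < m; ++j)
                Q[i][j] = M[r0 + i][c0 + j];
        return Q;
    };
    const Matrix A11 = quad(A, 0), A12 = quad(A, 1), A21 = quad(A, 2), A22 = quad(A, 3);
    const Matrix B11 = quad(B, 0), B12 = quad(B, 1), B21 = quad(B, 2), B22 = quad(B, 3);
    const Matrix M1 = strassenPow2(addSub(A11, A22, 1), addSub(B11, B22, 1), threshold);
    const Matrix M2 = strassenPow2(addSub(A21, A22, 1), B11, threshold);
    const Matrix M3 = strassenPow2(A11, addSub(B12, B22, -1), threshold);
    const Matrix M4 = strassenPow2(A22, addSub(B21, B11, -1), threshold);
    const Matrix M5 = strassenPow2(addSub(A11, A12, 1), B22, threshold);
    const Matrix M6 = strassenPow2(addSub(A21, A11, -1), addSub(B11, B12, 1), threshold);
    const Matrix M7 = strassenPow2(addSub(A12, A22, -1), addSub(B21, B22, 1), threshold);
    Matrix C = zeros(n);
    for (std::size_t i = 0; i < m; ++i)
        for (std::size_t j = 0; j < m; ++j) {
            C[i][j]         = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j]; // C11
            C[i][j + m]     = M3[i][j] + M5[i][j];                       // C12
            C[i + m][j]     = M2[i][j] + M4[i][j];                       // C21
            C[i + m][j + m] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j]; // C22
        }
    return C;
}

// Any square size: zero-pad to the next power of two, multiply, then crop back to n x n.
Matrix strassenMultiply(const Matrix& A, const Matrix& B, std::size_t threshold = 64) {
    const std::size_t n = A.size();
    assert(n > 0 && B.size() == n && threshold >= 1); // threshold 0 would recurse into empty blocks
    std::size_t p = 1;
    while (p < n) p *= 2;
    Matrix PA = zeros(p), PB = zeros(p);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j) { PA[i][j] = A[i][j]; PB[i][j] = B[i][j]; }
    const Matrix PC = strassenPow2(PA, PB, threshold);
    Matrix C = zeros(n);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j) C[i][j] = PC[i][j];
    return C;
}
```

With threshold = 1 the recursion runs all the way down to 1×1 blocks, which is handy for checking correctness but slow in practice. The listing's cutoff of 64 reflects the same trade-off: for small blocks the extra additions and allocations cost more than the multiplication they save, and the best cutoff depends on the machine.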