Strassen算法之矩阵乘法
来源:互联网 发布:b超数据算胎儿体重软件 编辑:程序博客网 时间:2024/04/27 05:27
Strassen算法之矩阵乘法
问题:给定两个n-by-n矩阵A和B,计算C=AB;
分析:如果采用一般方法求解矩阵C,我们根据乘法定义知道C中每个元素都需要O(n)次乘法,总共有n^2个元素,所以时间复杂度是O(n^3)。当n很大时,这个时间是非常久的。那我们有什么快速的方法计算矩阵乘法呢?采用Divide、Conquer and Combine思想,把矩阵A、B、C分别画成4个小矩阵,这样.把每个问题分成8个子问题和4次加法。到此,这是分治策略的方法,时间复杂度也是O(n^3)。但实际上我们不需要计算8个子问题,值需要计算7个矩阵的结果就能表示出来。
而我们的C矩阵可以如下表示:
这就是Strassen算法求矩阵乘法思想,时间复杂度降到O(n^2.807)。具体实现过程如下:
/*******************************************************Description:利用Strassen算法求矩阵乘法Time complexity:O(logN^2.81) Author:Robert.Tianyi********************************************************/#include<stdio.h>#include<stdlib.h>int * Strassen_MatrixMultiplication(int *a,int *b,int size);/*动态开辟空间*/ int *creat2LayerPointer(int row,int column){ int *a=NULL; a=(int*)malloc(sizeof(int)*row*column);//开辟row*column个空间 return a; }/*矩阵加法*/ int * MatrixAdd(int *a,int *b,int size){ int *result=NULL; int i,j; result=creat2LayerPointer(size,size); for(j=0;j<size*size;j++) result[j]=a[j]+b[j]; return result; }/*矩阵减法*/ int * MatrixSub(int *a,int *b,int size){ int *result=NULL; int i,j; result=creat2LayerPointer(size,size); for(j=0;j<size*size;j++) result[j]=a[j]-b[j]; return result; }void main(){ int a[64],b[64]; int *pointera=NULL,*pointerb=NULL; int *c=NULL; int i,j; c=creat2LayerPointer(8,8); for(i=0;i<64;i++){ a[i]=i; b[i]=i+1; } pointera=&a[0]; pointerb=&b[0]; printf("a矩阵如下:\n"); for(i=0;i<64;i++){ if(i%8==0) printf("\n"); printf("%4d ",a[i]); } printf("\nb矩阵如下:\n"); for(i=0;i<64;i++){ if(i%8==0) printf("\n"); printf("%4d ",b[i]); } printf("\n\n采用Strassen算法,计算a*b结果如下:\n"); c=Strassen_MatrixMultiplication(pointera,pointerb,8); for(i=0;i<64;i++){ if(i%8==0) printf("\n"); printf("%8d",*(c++)); } free(c); }int * Strassen_MatrixMultiplication(int *a,int *b,int size){ int temp_size; int *a11=NULL,*a12=NULL,*a21=NULL,*a22=NULL,*b11=NULL,*b12=NULL,*b21=NULL,*b22=NULL; int *s1=NULL,*s2=NULL,*s3=NULL,*s4=NULL,*s5=NULL,*s6=NULL,*s7=NULL,*s8=NULL,*s9=NULL,*s10=NULL; int *P1=NULL,*P2=NULL,*P3=NULL,*P4=NULL,*P5=NULL,*P6=NULL,*P7=NULL; int *c11=NULL,*c12=NULL,*c21=NULL,*c22=NULL; int *C=NULL,*temp_C=NULL; int i,j; temp_size=size/2; if(size==2){//递归停止条件 C=creat2LayerPointer(size,size); C[0]=a[0]*b[0]+a[1]*b[2]; C[1]=a[0]*b[1]+a[1]*b[3]; C[2]=a[2]*b[0]+a[3]*b[2]; C[3]=a[2]*b[1]+a[3]*b[3]; /*释放指针*/ free(a11); free(a12); free(a21); free(a22); free(b11); free(b12); free(b21); free(b22); free(c11); free(c12); free(c21); free(c22); free(s1);free(s2);free(s3);free(s4);free(s5); free(s6);free(s6);free(s8);free(s9);free(s10); free(P1);free(P2);free(P3);free(P4);free(P5);free(P6);free(P7); // temp_C=C; // free(C); return C; } else{ /*动态给矩阵a,b,c开辟空间*/ a11=creat2LayerPointer(temp_size,temp_size); a12=creat2LayerPointer(temp_size,temp_size); a21=creat2LayerPointer(temp_size,temp_size); a22=creat2LayerPointer(temp_size,temp_size); b11=creat2LayerPointer(temp_size,temp_size); b12=creat2LayerPointer(temp_size,temp_size); b21=creat2LayerPointer(temp_size,temp_size); b22=creat2LayerPointer(temp_size,temp_size); c11=creat2LayerPointer(temp_size,temp_size); c12=creat2LayerPointer(temp_size,temp_size); c21=creat2LayerPointer(temp_size,temp_size); c22=creat2LayerPointer(temp_size,temp_size); C=creat2LayerPointer(size,size); s1=creat2LayerPointer(temp_size,temp_size); s2=creat2LayerPointer(temp_size,temp_size); s3=creat2LayerPointer(temp_size,temp_size); s4=creat2LayerPointer(temp_size,temp_size); s5=creat2LayerPointer(temp_size,temp_size); s6=creat2LayerPointer(temp_size,temp_size); s7=creat2LayerPointer(temp_size,temp_size); s8=creat2LayerPointer(temp_size,temp_size); s9=creat2LayerPointer(temp_size,temp_size); s10=creat2LayerPointer(temp_size,temp_size); P1=creat2LayerPointer(temp_size,temp_size); P2=creat2LayerPointer(temp_size,temp_size); P3=creat2LayerPointer(temp_size,temp_size); P4=creat2LayerPointer(temp_size,temp_size); P5=creat2LayerPointer(temp_size,temp_size); P6=creat2LayerPointer(temp_size,temp_size); P7=creat2LayerPointer(temp_size,temp_size); /*矩阵a b进行分割成4个小矩阵*/ for(i=0;i<temp_size;i++) for(j=0;j<temp_size;j++){ a11[i*temp_size+j]=a[i*size+j]; a12[i*temp_size+j]=a[i*size+j+temp_size] ; a21[i*temp_size+j]=a[2*temp_size*temp_size+i*size+j]; a22[i*temp_size+j]=a[2*temp_size*temp_size+i*size+j+temp_size]; b11[i*temp_size+j]=b[i*size+j]; b12[i*temp_size+j]=b[i*size+j+temp_size]; b21[i*temp_size+j]=b[2*temp_size*temp_size+i*size+j]; b22[i*temp_size+j]=b[2*temp_size*temp_size+i*size+j+temp_size]; } s1=MatrixSub(b12,b22,temp_size); s2=MatrixAdd(a11,a12,temp_size); s3=MatrixAdd(a21,a22,temp_size); s4=MatrixSub(b21,b11,temp_size); s5=MatrixAdd(a11,a22,temp_size); s6=MatrixAdd(b11,b22,temp_size); s7=MatrixSub(a12,a22,temp_size); s8=MatrixAdd(b21,b22,temp_size); s9=MatrixSub(a11,a21,temp_size); s10=MatrixAdd(b11,b12,temp_size); /*迭代*/ P1=Strassen_MatrixMultiplication(a11,s1,temp_size); P2=Strassen_MatrixMultiplication(s2,b22,temp_size); P3=Strassen_MatrixMultiplication(s3,b11,temp_size); P4=Strassen_MatrixMultiplication(a22,s4,temp_size); P5=Strassen_MatrixMultiplication(s5,s6,temp_size); P6=Strassen_MatrixMultiplication(s7,s8,temp_size); P7=Strassen_MatrixMultiplication(s9,s10,temp_size); c11=MatrixAdd(MatrixSub(MatrixAdd(P5,P4,temp_size),P2,temp_size),P6,temp_size); c12=MatrixAdd(P1,P2,temp_size); c21=MatrixAdd(P3,P4,temp_size); c22=MatrixSub(MatrixSub(MatrixAdd(P5,P1,temp_size),P3,temp_size),P7,temp_size); /*将4个小块矩阵合并到C*/ for(i=0;i<temp_size;i++){ for(j=0;j<temp_size;j++){ C[i*size+j]=c11[i*temp_size+j]; C[i*size+j+temp_size]=c12[i*temp_size+j]; C[2*temp_size*temp_size+i*size+j]=c21[i*temp_size+j]; C[2*temp_size*temp_size+i*size+j+temp_size]=c22[i*temp_size+j]; } } return C; }}
0 0
- 矩阵乘法 之 strassen 算法
- Strassen算法之矩阵乘法
- STRASSEN算法(矩阵乘法)
- strassen矩阵乘法算法
- strassen算法(矩阵乘法)
- strassen算法优化矩阵乘法
- 算法导论--------------Strassen矩阵乘法
- 矩阵乘法的Strassen算法
- Strassen矩阵乘法算法实现
- 矩阵乘法的Strassen算法
- 贪心算法-Strassen矩阵乘法
- 算法导论之四矩阵乘法的Strassen算法
- 算法重拾之路——strassen矩阵乘法
- 《算法导论》学习笔记之Chapter4.2矩阵乘法Strassen
- 【算法导论】矩阵乘法strassen算法
- 算法导论-矩阵乘法-strassen算法
- 矩阵乘法(Strassen算法/C++实现)
- 使用python实现Strassen矩阵乘法算法
- 说说javaScript中要注意的问题
- Ubuntu更新出现 The system is running in low-graphics mode解决
- 第五周 项目一(2)矩形面积(Raptor)
- 输出员工信息并计算员工的工资
- android debug bridge的环境配置
- Strassen算法之矩阵乘法
- Python的查车票小工具
- 数据结构-广义表(GeneralizedList)实现
- golang生产者与消费者
- Jquery - map、has、each 方法简述
- creating server tcp listening socket 127.0.0.1:6379: bind No error
- altera 全局时钟资源的利用
- Leetcode-8. String to Integer (atoi)
- gstreamer成功安装后 出现 “未定义的引用”的错误的解决方案