Strassen算法之矩阵乘法

来源:互联网 发布:b超数据算胎儿体重软件 编辑:程序博客网 时间:2024/04/27 05:27

Strassen算法之矩阵乘法

问题:给定两个n-by-n矩阵A和B,计算C=AB;
分析:如果采用一般方法求解矩阵C,我们根据乘法定义知道C中每个元素都需要O(n)次乘法,总共有n^2个元素,所以时间复杂度是O(n^3)。当n很大时,这个时间是非常久的。那我们有什么快速的方法计算矩阵乘法呢?采用Divide、Conquer and Combine思想,把矩阵A、B、C分别画成4个小矩阵,这样.把每个问题分成8个子问题和4次加法。到此,这是分治策略的方法,时间复杂度也是O(n^3)。但实际上我们不需要计算8个子问题,值需要计算7个矩阵的结果就能表示出来。

(C11C21C12C22)=(a11a21a12a22)(b11b21b12b22)

P1=a11(b12b22)
P2=(a11+a12)b22
P3=(a21+a22)b11
P4=a22(b21b11)
P5=(a11+a22)(b11+b22)
P6=(a12a22)(b21+b22)
P7=(a11a21)(b11+b12)

而我们的C矩阵可以如下表示:
C11=P4+P5+P6P2
C12=P1+P2
C21=P3+P4
C22=P1+P5P3P7

这就是Strassen算法求矩阵乘法思想,时间复杂度降到O(n^2.807)。具体实现过程如下:

/*******************************************************Description:利用Strassen算法求矩阵乘法Time complexity:O(logN^2.81) Author:Robert.Tianyi********************************************************/#include<stdio.h>#include<stdlib.h>int * Strassen_MatrixMultiplication(int *a,int *b,int size);/*动态开辟空间*/ int *creat2LayerPointer(int row,int column){    int *a=NULL;    a=(int*)malloc(sizeof(int)*row*column);//开辟row*column个空间     return a; }/*矩阵加法*/ int * MatrixAdd(int *a,int *b,int size){     int *result=NULL;    int i,j;    result=creat2LayerPointer(size,size);    for(j=0;j<size*size;j++)        result[j]=a[j]+b[j];    return result;  }/*矩阵减法*/ int * MatrixSub(int *a,int *b,int size){    int *result=NULL;    int i,j;    result=creat2LayerPointer(size,size);    for(j=0;j<size*size;j++)            result[j]=a[j]-b[j];    return result;  }void main(){    int a[64],b[64];    int *pointera=NULL,*pointerb=NULL;    int *c=NULL;    int i,j;    c=creat2LayerPointer(8,8);    for(i=0;i<64;i++){        a[i]=i;        b[i]=i+1;     }    pointera=&a[0];    pointerb=&b[0];             printf("a矩阵如下:\n");    for(i=0;i<64;i++){        if(i%8==0)            printf("\n");        printf("%4d ",a[i]);    }    printf("\nb矩阵如下:\n");    for(i=0;i<64;i++){        if(i%8==0)            printf("\n");        printf("%4d ",b[i]);    }    printf("\n\n采用Strassen算法,计算a*b结果如下:\n");    c=Strassen_MatrixMultiplication(pointera,pointerb,8);     for(i=0;i<64;i++){        if(i%8==0)            printf("\n");        printf("%8d",*(c++));       }    free(c);    }int * Strassen_MatrixMultiplication(int *a,int *b,int size){    int temp_size;    int *a11=NULL,*a12=NULL,*a21=NULL,*a22=NULL,*b11=NULL,*b12=NULL,*b21=NULL,*b22=NULL;    int *s1=NULL,*s2=NULL,*s3=NULL,*s4=NULL,*s5=NULL,*s6=NULL,*s7=NULL,*s8=NULL,*s9=NULL,*s10=NULL;    int *P1=NULL,*P2=NULL,*P3=NULL,*P4=NULL,*P5=NULL,*P6=NULL,*P7=NULL;    int *c11=NULL,*c12=NULL,*c21=NULL,*c22=NULL;    int *C=NULL,*temp_C=NULL;    int i,j;    temp_size=size/2;    if(size==2){//递归停止条件         C=creat2LayerPointer(size,size);        C[0]=a[0]*b[0]+a[1]*b[2];        C[1]=a[0]*b[1]+a[1]*b[3];        C[2]=a[2]*b[0]+a[3]*b[2];        C[3]=a[2]*b[1]+a[3]*b[3];        /*释放指针*/         free(a11);  free(a12);  free(a21);  free(a22);        free(b11);  free(b12);  free(b21);  free(b22);        free(c11);  free(c12);  free(c21);  free(c22);        free(s1);free(s2);free(s3);free(s4);free(s5);        free(s6);free(s6);free(s8);free(s9);free(s10);          free(P1);free(P2);free(P3);free(P4);free(P5);free(P6);free(P7);     //  temp_C=C;    //  free(C);        return C;    }    else{        /*动态给矩阵a,b,c开辟空间*/         a11=creat2LayerPointer(temp_size,temp_size);        a12=creat2LayerPointer(temp_size,temp_size);        a21=creat2LayerPointer(temp_size,temp_size);        a22=creat2LayerPointer(temp_size,temp_size);        b11=creat2LayerPointer(temp_size,temp_size);        b12=creat2LayerPointer(temp_size,temp_size);        b21=creat2LayerPointer(temp_size,temp_size);        b22=creat2LayerPointer(temp_size,temp_size);        c11=creat2LayerPointer(temp_size,temp_size);        c12=creat2LayerPointer(temp_size,temp_size);        c21=creat2LayerPointer(temp_size,temp_size);        c22=creat2LayerPointer(temp_size,temp_size);        C=creat2LayerPointer(size,size);        s1=creat2LayerPointer(temp_size,temp_size);        s2=creat2LayerPointer(temp_size,temp_size);        s3=creat2LayerPointer(temp_size,temp_size);        s4=creat2LayerPointer(temp_size,temp_size);        s5=creat2LayerPointer(temp_size,temp_size);        s6=creat2LayerPointer(temp_size,temp_size);        s7=creat2LayerPointer(temp_size,temp_size);        s8=creat2LayerPointer(temp_size,temp_size);        s9=creat2LayerPointer(temp_size,temp_size);        s10=creat2LayerPointer(temp_size,temp_size);        P1=creat2LayerPointer(temp_size,temp_size);        P2=creat2LayerPointer(temp_size,temp_size);        P3=creat2LayerPointer(temp_size,temp_size);        P4=creat2LayerPointer(temp_size,temp_size);        P5=creat2LayerPointer(temp_size,temp_size);        P6=creat2LayerPointer(temp_size,temp_size);        P7=creat2LayerPointer(temp_size,temp_size);        /*矩阵a b进行分割成4个小矩阵*/         for(i=0;i<temp_size;i++)            for(j=0;j<temp_size;j++){                a11[i*temp_size+j]=a[i*size+j];                a12[i*temp_size+j]=a[i*size+j+temp_size]    ;                a21[i*temp_size+j]=a[2*temp_size*temp_size+i*size+j];                a22[i*temp_size+j]=a[2*temp_size*temp_size+i*size+j+temp_size];                b11[i*temp_size+j]=b[i*size+j];                b12[i*temp_size+j]=b[i*size+j+temp_size];                b21[i*temp_size+j]=b[2*temp_size*temp_size+i*size+j];                b22[i*temp_size+j]=b[2*temp_size*temp_size+i*size+j+temp_size];            }        s1=MatrixSub(b12,b22,temp_size);        s2=MatrixAdd(a11,a12,temp_size);        s3=MatrixAdd(a21,a22,temp_size);        s4=MatrixSub(b21,b11,temp_size);        s5=MatrixAdd(a11,a22,temp_size);        s6=MatrixAdd(b11,b22,temp_size);        s7=MatrixSub(a12,a22,temp_size);        s8=MatrixAdd(b21,b22,temp_size);        s9=MatrixSub(a11,a21,temp_size);        s10=MatrixAdd(b11,b12,temp_size);        /*迭代*/         P1=Strassen_MatrixMultiplication(a11,s1,temp_size);        P2=Strassen_MatrixMultiplication(s2,b22,temp_size);        P3=Strassen_MatrixMultiplication(s3,b11,temp_size);        P4=Strassen_MatrixMultiplication(a22,s4,temp_size);        P5=Strassen_MatrixMultiplication(s5,s6,temp_size);        P6=Strassen_MatrixMultiplication(s7,s8,temp_size);        P7=Strassen_MatrixMultiplication(s9,s10,temp_size);        c11=MatrixAdd(MatrixSub(MatrixAdd(P5,P4,temp_size),P2,temp_size),P6,temp_size);        c12=MatrixAdd(P1,P2,temp_size);        c21=MatrixAdd(P3,P4,temp_size);        c22=MatrixSub(MatrixSub(MatrixAdd(P5,P1,temp_size),P3,temp_size),P7,temp_size);        /*将4个小块矩阵合并到C*/         for(i=0;i<temp_size;i++){            for(j=0;j<temp_size;j++){                C[i*size+j]=c11[i*temp_size+j];                C[i*size+j+temp_size]=c12[i*temp_size+j];                C[2*temp_size*temp_size+i*size+j]=c21[i*temp_size+j];                C[2*temp_size*temp_size+i*size+j+temp_size]=c22[i*temp_size+j];             }           }        return C;       }}
0 0