矩阵乘法 模板函数的实现 可以处理多维矩阵 c++

来源:互联网 发布:cssci数据库怎么查询 编辑:程序博客网 时间:2024/05/16 05:49

文中计算结果是与Matlab对比过的。但是,如有发现错误谢谢告之。

函数实现:

/*********************************************************************   函数功能:(aM*bN的矩阵mR) = (aM*aN的矩阵mA) x (bM*bN的矩阵mB)。**前行后列(前高后宽)即为相乘结果矩阵的大小:**前矩阵的行数(高aM) x 后矩阵的列数(宽bN)。**********************************************************************/template <typename T1, typename T2, typename T3>int Matrix_Mult(T1* mA,//矩阵Aint aM,//矩阵A的行数(高)int aN,//矩阵A的列数(宽)T2* mB,//矩阵Bint bM,//矩阵B的行数(高)int bN,//矩阵B的列数(宽)T3* mR,//矩阵R=矩阵A*矩阵Bint chan=1  //例如RGB24可以看成是R、G、B 3个通道){//前矩阵的宽(列数)必须等于后矩阵的高(行数)if (aN != bM || mA==0 || mB==0 || mR==0){return -1;}int iTemp0 = 0;int index0 = 0;int index1 = 0;//图像处理一般四通道就够了,比如RGB32T3* sum[10]= {0};for (int i=0; i<chan; i++){sum[i] = (T3*)malloc(bN*sizeof(T3));}//循环处理矩阵A每一行for (int am=0; am<aM; am++){    for (int i=0; i<chan; i++){memset(sum[i], 0, sizeof(T3)*bN);}//矩阵Bfor (int bm=0; bm<bM; bm++){iTemp0 = bm*bN;        for (int bn=0; bn<bN; bn++){index0 = iTemp0 + bn;    index1 = am*aN + bm;  //计算矩阵A对应位置        if (bm == 0){//计算各个通道for (int i=0; i<chan; i++){sum[i][bn*chan+i] = (T3)(mB[index0*chan+i]*mA[index1*chan+i]);}}else{//计算各个通道for (int i=0; i<chan; i++){sum[i][bn*chan+i] += (T3)(mB[index0*chan+i]*mA[index1*chan+i]);}}                }}    //计算各个通道for (int i=0; i<chan; i++){memcpy(mR+am*bN*chan+i, sum[i], sizeof(T3)*bN);}}//返回结果矩阵的成员数return aM*bN*chan;}
函数调用:
int _tmain(int argc, _TCHAR* argv[]){int irt = 0;//-----------  一维数组 --------------------------------------cout<<"------------------  矩阵相乘1 -------------------------\n";double mA[16] = {1,2,3,4,  5,6,7,8,  9,10,11,12,  13,14,15,16};double mB[16] = {1,2,3,4,  5,6,7,8,  9,10,11,12,  13,14,15,16};double mR[16] = {0};irt = Matrix_Mult<double,double, double>(mA, 4, 4, mB, 4, 4, mR);for (int i=0 ;i<irt; i++){cout<<mR[i]<<"  ";if ((i+1)%4 == 0){cout<<endl;}}cout<<endl;cout<<"------------------  矩阵相乘2 -------------------------\n";int mA2[4]= {1,1,  2,0};int mB2[6]= {0,2,3,  1,1,2};double mR2[6]= {0};irt =  Matrix_Mult<int, int, double>(mA2, 2, 2, mB2, 2, 3, mR2);for (int i=0 ;i<6; i++){cout<<mR2[i]<<"  ";if ((i+1)%3 == 0){cout<<endl;}}cout<<endl;cout<<"------------------  矩阵相乘3 -------------------------\n";float mA3[4]= {1.1f, 2.2f, 3.3f};int mB3[6]= {1,2, 3,4, 5,6};double mR3[6]= {0};irt = Matrix_Mult<float, int, double>(mA3, 1, 3, mB3, 3, 2, mR3);for (int i=0 ;i<irt; i++){cout<<mR3[i]<<"  ";}cout<<endl<<endl;//-----------  二维数组 --------------------------------------cout<<"------------------  矩阵相乘4 -------------------------\n";double mA1[4][4] = {{1,2,3,4} , {5,6,7,8} , {9,10,11,12} , {13,14,15,16}};double mB1[4][4] = {{1,2,3,4} , {5,6,7,8}  ,{9,10,11,12} , {13,14,15,16}};double mR1[4][4] = {0};Matrix_Mult<double,double, double>(&mA1[0][0], 4, 4, &mB1[0][0], 4, 4, &mR1[0][0]);for (int n=0; n<4; n++){for (int m=0; m<4; m++){cout<<mR1[n][m]<<"  ";}cout<<endl;}cout<<endl;return 0;}
运行结果: