c++的矩阵乘法加速trick

来源:互联网 发布:base64加密 c语言 编辑:程序博客网 时间:2024/06/01 09:10

c++的矩阵乘法加速trick

最近读RNNLM的源代码,发现其实现矩阵乘法时使用了一个trick,这里描述一下这个trick。

首先是正常版的矩阵乘法(其实是矩阵乘向量)

void matrixXvector(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){    for(int row=0;row<srcmatrix_rownum;++row){        destvect[row]=0;        for(int col=0;col<srcmatrix_colnum;++col){            destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];        }    }}

就是最简单的for循环,逐行逐列遍历。

接下来是RNNLM中实现的trick版本

void matrixXvector2(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){    int row, col;    float val1, val2, val3, val4;    float val5, val6, val7, val8;        for(row=0;row<srcmatrix_rownum/8;++row){        val1 = 0;        val2 = 0;        val3 = 0;        val4 = 0;        val5 = 0;        val6 = 0;        val7 = 0;        val8 = 0;                for(col=0;col<srcmatrix_colnum;++col){            val1+=srcmatrix[(row*8+0)*srcmatrix_colnum+col]*srcvect[col];            val2+=srcmatrix[(row*8+1)*srcmatrix_colnum+col]*srcvect[col];            val3+=srcmatrix[(row*8+2)*srcmatrix_colnum+col]*srcvect[col];            val4+=srcmatrix[(row*8+3)*srcmatrix_colnum+col]*srcvect[col];            val5+=srcmatrix[(row*8+4)*srcmatrix_colnum+col]*srcvect[col];            val6+=srcmatrix[(row*8+5)*srcmatrix_colnum+col]*srcvect[col];            val7+=srcmatrix[(row*8+6)*srcmatrix_colnum+col]*srcvect[col];            val8+=srcmatrix[(row*8+7)*srcmatrix_colnum+col]*srcvect[col];        }                destvect[row*8+0]+=val1;        destvect[row*8+1]+=val2;        destvect[row*8+2]+=val3;        destvect[row*8+3]+=val4;        destvect[row*8+4]+=val5;        destvect[row*8+5]+=val6;        destvect[row*8+6]+=val7;        destvect[row*8+7]+=val8;            }        for(row=row*8;row<srcmatrix_rownum;++row){        for(col=0;col<srcmatrix_colnum;++col){            destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];            }    }}

对比普通版,trick版把遍历行的for循环分成了8份,同时进行列遍历。

实际测试中,这个trick版比普通版快了接近2倍~这是编译器优化造成的么……?

参考:http://www.cnblogs.com/plwang1990/p/4139357.html

0 0
原创粉丝点击