caffe源码中关于矩阵运算的一些函数

来源:互联网 发布:金正恩的发型 知乎 编辑:程序博客网 时间:2024/06/08 07:15

caffe中的一些矩阵运算函数,定义在文件caffe\src\caffe\util\math_function.cpp和文件caffe\include\caffe\util\math_function.hpp里


它们实际上是cblas的精简版,因此,math_function.hpp的头文件一般包含cblas相关文件。


当我们在某些文件中出现关于矩阵运算的函数的时候,首先在math_function.cpp里找到该函数名称,查看该函数调用的cblas函数名,然后,查看cblas手册,得到该函数的具体描述。


矩阵中的数据存储在一片连续的内存中,即它的数据结构实际上是vector。cblas中的函数通过指向这些内存首地址的指针来访问矩阵中的数据。为了让这些vector具有矩阵的特性,还要向cblas中的函数传递用来描述这些矩阵形状、尺寸等信息的参数,例如矩阵的行数、列数、是否转置等等。


下面是math_function.cpp文件中的一段段代码:

</pre><pre style="margin-top: 0px; margin-bottom: 0px;"><span style=" color:#808000;">template</span><>

void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  int lda = (TransA == CblasNoTrans) ? K : M;
  int ldb = (TransB == CblasNoTrans) ? N : K;
  cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
      ldb, beta, C, N);
}

该函数实现的矩阵运算是

C=alpha*op(A)op(B)+beta*C

其中:

1) alpha和beta标量

2) 如果CblasNoTrans==1,op(A)==A,否则,op(A)==A'。op(B)的意思类推。

3) M、N、K描述的是矩阵的形状,其中,op(A)的形状是M*K,op(B)的形状是K*N,C的形状是M*N。

4) 关于lda和ldb,这里给出了解释:http://icl.cs.utk.edu/lapack-forum/viewtopic.php?p=661&sid=67c66465dedfcbb6e0612cca7647698f

把解释贴出来:

Suppose that you have a matrix A of size 100x100 which is stored in an array 100x100. In this case LDA is the same as N. Now suppose that you want to work only on the submatrix A(91:100 , 1:100); in this case the number of rows is 10 but LDA=100. Assuming the fortran column-major ordering (which is the case in LAPACK),the LDA is used to define the distance in memory between elements of two consecutive columns which have the same row index. If you call B = A(91:100 , 1:100) then B(1,1) and B(1,2) are 100 memory locations far from each other. 

注意,上述是根据Fortran语言来举例子的,Fortran语言中矩阵是列存储,也就是矩阵的同一列在一片连续内存中,而在C\C++中,是按行存储的。


查看其它函数功能的方法类似,就是先在math_function.cpp中看它调用了哪个cblas中的函数,然后翻看cblas手册查看该函数的功能。

0 0
原创粉丝点击