numpy.einsum-学习笔记

来源:互联网 发布:淘宝什么是自主访问 编辑:程序博客网 时间:2024/05/21 21:36

numpy.einsum-学习笔记

元素

函数呼叫的简单形式如下

C = np.einsum("ijk,km",A,B)

或者

C = np.einsum("ijk,km->m",A,B)

我们可以这样来看,np.einsum 的第一个参数是一个string,由两个substring组成,由逗号分隔。第二个和第三个参数都是np.array

这个函数的功能呢,就是让我们定义AB 两个数组的运算。

那么什么是定义运算呢?基本上我们可以理解为定义下面一句话(设AB运算的结果是C),C的某一个元素是由A的某一些元素与B的某一些元素通过什么操作得到的。

那要定义运算,首先我们要有办法指出每个元素在AB 中的位置(需要indices)。(然后才能定义运算:例如,让A的某个位置的元素和B的某个位置的元素相乘/相加等等)

注意,这边的元素要求是scalar!(例如,3d-array需要3个维度对应到一个实数项(scalar),而2个维度对应的是vector,而1个维度对应的是matrix)

那第一参数我们就在做上面这件事,, 分割的两个substring定义了所需要的维度指代(先只看->的左边)每一个字母对应相应的维度(或者axis,从小到大)),例如上面的代码中,第一个substringijk 对应到第一个数组A 的indices,那么定义运算的时候指A 的scalar就是: A[i,j,k] 表示 ,同理我们定义运算的时候指B的scalar就是B[k,m] 的indices

没有->

此时的运算规则很简单,就是将两个substring重合的indices去做和(这样就去掉了这个维度),其余的保留作为index,例如

C = np.einsum("ijk,km",A,B)

就是

c[i,j,m]=ka[i,j,k]b[k,m]

请注意上面c[i,j,m], a[i,j,k], b[k,m] 都是scalar(实数)。则c[i,j,m] 想要表达的是这样一件事:C 是这样一个东西(Tensor):它的(i,j,m) 位置(此时C 是一个3维数组)的元素项由上式右边计算得到。

此外,上式也隐含这样的条件:A 的第三维度和B 的第一维度的大小应该相同。

->

-> 和上面略有不同。-> 右边定义了结果的样子(因此也定义了运算)例如,

C = np.einsum("ijk,km->k",A,B)

此时,上面的代码要表达的是:C 的第(k) 项(此时C 是一个1维数组)是这样得到的

c[k]=i,j,ma[i,j,k]b[k,m]

例子

A = np.array([    [[1, 2], [3, 4]],    [[5, 6], [7, 8]]])B = np.array([    [9, 10],    [11, 12],    [13, 14]])

Case 1

C = np.einsum('ijk,mn', A, B)

首先没有->,
其次此时两个substring没有重合的index,
则结果

c[i,j,k,m,n]=A[i,j,k]B[m,n]

Case 2

C = np.einsum('ijk,jk', A, B)

首先没有->,
其次此时两个substring有重合indices: j, k
理论上结果

c[i]=jkA[i,j,k]B[j,k]

然而A 的shape是(2,2,2) 而B的shape是(3,2) 我们发现j 对应在AB中维度的大小不一致,所以会报错

Case 3

C = np.einsum('ijk,mk', A, B)

首先没有->,
其次此时两个substring有重合index: k
理论上结果

c[i,j,m]=kA[i,j,k]B[m,k]

k 对应在AB中维度的大小一致(都为2),没问题,结果是

>>>[[[ 29  35  41]     [ 67  81  95]]    [[105 127 149]     [143 173 203]]]

简单验证下:

c[0,0,0]=k=01A[0,0,k]B[0,k]=A[0,0,0]B[0,0]+A[0,0,1]B[0,1]

根据前面的
A[0,0,0]=1B[0,0]=9A[0,0,1]=2B[0,10]

结果是
c[0,0,0]=19+210=29

Case 4

C = np.einsum('ijk,mk->i', A, B)

首先有->,则-> 右边定义了结果的样子,
则运算

c[i]=j2k2m2A[i,j,k]B[m,k]

结果是,

>>>[348 900]

请自行验证

Case 5

C = np.einsum('ijk,mk->k', A, B)

则运算

c[k]=i2j2m2A[i,j,k]B[m,k]

结果是,

>>>[528 720]

请自行验证

原创粉丝点击