2017-10-12 75 views
2

我具有相同的尺寸a的兩個3 d陣列和b矩陣乘以每對2-d陣列沿着第一維度與einsum

np.random.seed([3,14159]) 
a = np.random.randint(10, size=(4, 3, 2)) 
b = np.random.randint(10, size=(4, 3, 2)) 

print(a) 

[[[4 8] 
    [1 1] 
    [9 2]] 

[[8 1] 
    [4 2] 
    [8 2]] 

[[8 4] 
    [9 4] 
    [3 4]] 

[[1 5] 
    [1 2] 
    [6 2]]] 

print(b) 

[[[7 7] 
    [1 1] 
    [7 8]] 

[[7 4] 
    [8 0] 
    [0 9]] 

[[3 8] 
    [7 7] 
    [2 6]] 

[[3 1] 
    [9 3] 
    [0 5]]] 

我想從

a[0] 

[[4 8] 
[1 1] 
[9 2]] 

而且從b

b[0] 

[[7 7] 
[1 1] 
[7 8]] 

的第一個,返回此

a[0].T.dot(b[0]) 

[[ 92 101] 
[ 71 73]] 

但我想這樣做,在整個第一個維度。我想我可以使用np.einsum

np.einsum('abc,ade->ace', a, b) 

[[[210 224] 
    [165 176]] 

[[300 260] 
    [ 75 65]] 

[[240 420] 
    [144 252]] 

[[ 96 72] 
    [108 81]]] 

這是正確的形狀,但不是值。

我希望得到這樣的:

np.array([x.T.dot(y).tolist() for x, y in zip(a, b)]) 

[[[ 92 101] 
    [ 71 73]] 

[[ 88 104] 
    [ 23 22]] 

[[ 93 145] 
    [ 48 84]] 

[[ 12 34] 
    [ 33 21]]] 

回答

2

的矩陣乘法數額的,其中總和被在中軸線的乘積之和,因此指數b應該是兩個陣列相同的:(即改變adeabe):

In [40]: np.einsum('abc,abe->ace', a, b) 
Out[40]: 
array([[[ 92, 101], 
     [ 71, 73]], 

     [[ 88, 104], 
     [ 23, 22]], 

     [[ 93, 145], 
     [ 48, 84]], 

     [[ 12, 34], 
     [ 33, 21]]]) 

當輸入陣列具有缺少輸出數組中的索引標, 他們被獨立地總結。也就是說,

np.einsum('abc,ade->ace', a, b) 

相當於

In [44]: np.einsum('abc,ade->acebd', a, b).sum(axis=-1).sum(axis=-1) 
Out[44]: 
array([[[210, 224], 
     [165, 176]], 

     [[300, 260], 
     [ 75, 65]], 

     [[240, 420], 
     [144, 252]], 

     [[ 96, 72], 
     [108, 81]]]) 
+0

謝謝。這有助於澄清我錯過的東西。 – piRSquared

2

這裏有一個與np.matmul因爲我們需要的a第二軸推回至年底,因此它會得到sum-reduced對從第二軸b,同時保持他們的第一軸對準 -

np.matmul(a.swapaxes(1,2),b) 

示意性地提出:

啓動時:

a : M x N X R1 
b : M x N X R2 

與用於交換軸:

a : M x R1 X [N] 
b : M x [N] X R2 

括號內的軸得到sum-reduced,留給我們:

out : M x R1 X R2 

關於Python 3 。x,matmul@ operator負責 -

a.swapaxes(1,2) @ b 
+0

我不知道'np.matmul'。謝謝。 – piRSquared