2017-09-23 78 views
1

給定形狀(2,3,4)的基本數組X,它可以解釋爲兩個3個元素集合,每個元素都是4維的,我想要抽樣從這個數組X中以下面的方式。 從2組中的每一組我想挑選2個子集,每個子​​集由長度爲3的二進制數組定義,其他子集將設置爲0.因此,採樣過程由形狀(2,2,3)的數組定義。這個抽樣的結果應該是形狀的(2,2,3,4)。從numpy中的多個集合中挑選和分配多個子集

下面是我需要的代碼,但我想知道可以使用numpy索引更有效地重寫。

import numpy as np 
np.random.seed(3) 

sets = np.random.randint(0, 10, [2, 3, 4]) 
subset_masks = np.random.randint(0, 2, [2, 2, 3]) 

print('Base set\n', sets, '\n') 
print('Subset masks\n', subset_masks, '\n') 

result = np.empty([2, 2, 3, 4]) 
for set_index in range(sets.shape[0]): 
    for subset_index, subset in enumerate(subset_masks[set_index]): 
     print('----') 
     picked_subset = subset.reshape(3, 1) * sets[set_index] 
     result[set_index][subset_index] = picked_subset 
     print('Picking subset ', subset, 'from set #', set_index) 
     print(picked_subset, '\n') 

輸出

Base set 
[[[8 9 3 8] 
    [8 0 5 3] 
    [9 9 5 7]] 

[[6 0 4 7] 
    [8 1 6 2] 
    [2 1 3 5]]] 

Subset masks 
[[[0 0 1] 
    [1 0 0]] 

[[1 0 1] 
    [0 1 1]]] 

---- 
Picking subset [0 0 1] from set # 0 
[[0 0 0 0] 
[0 0 0 0] 
[9 9 5 7]] 

---- 
Picking subset [1 0 0] from set # 0 
[[8 9 3 8] 
[0 0 0 0] 
[0 0 0 0]] 

---- 
Picking subset [1 0 1] from set # 1 
[[6 0 4 7] 
[0 0 0 0] 
[2 1 3 5]] 

---- 
Picking subset [0 1 1] from set # 1 
[[0 0 0 0] 
[8 1 6 2] 
[2 1 3 5]] 

回答

1

它們中的每通過添加新的軸爲沿着最後一個subset_masks並延伸到4Dsets作爲第二軸。爲了添加這些新的軸,我們可以使用None/np.newaxis。然後,槓桿NumPy broadcasting進行逐元素相乘,像這樣 -

subset_masks[...,None]*sets[:,None] 

只是爲了踢或許,我們還可以使用np.einsum -

np.einsum('ijk,ilj->iljk',sets,subset_masks) 
+0

這兩個選項的工作。謝謝! – eclique