2017-10-21 131 views
1

我想收集指定軸中指定索引的元素,如下所示。如何收集numpy中特定索引的元素?

x = [[1,2,3], [4,5,6]] 
index = [[2,1], [0, 1]] 
x[:, index] = [[3, 2], [4, 5]] 

這實質上是在pytorch中的收集操作,但正如你所知道的,這在numpy中是不可行的。我想知道在numpy中是否有這樣的「聚集」操作?

回答

1
>>> x = np.array([[1,2,3], [4,5,6]]) 
>>> index = np.array([[2,1], [0, 1]]) 
>>> x_axis_index=np.tile(np.arange(len(x)), (index.shape[1],1)).transpose() 
>>> print x_axis_index 
[[0 0] 
[1 1]] 
>>> print x[x_axis_index,index] 
[[3 2] 
[4 5]] 
+0

注意還可以使用'np.arange(len(x))'不確定是否np.range是可取的! –

+0

注意:range(x.shape [0])和range(len(x))給出了一個列表,而np.arange(len(x))和np.arange(x.shape [0])給出了一個數組。數組和列表都有相同的元素。 – Sam17

+0

我想我的問題/陳述更多的是關於性能的問題,在一個非常大的數組中,我懷疑np.range的索引會更快(len vs shape肯定無關緊要)。 –