2015-07-12 61 views
0

欲索引通過多維陣列這樣numpy的索引:上的可變軸

a = range(12).reshape(3, 2, 2) 
def fun(axis, state): 
    # if axis=0 
    return a[state, :, :] 
    # if axis=1 it should return a[:, state, :] 

樣本輸出:

fun(0, 1) 
array([[4, 5],         
     [6, 7]]) 

fun(1, 1) 
array([[2, 3], 
     [6, 7], 
     [10, 11]]) 

總之我希望接受軸作爲參數。

我想不出有辦法做到這一點。任何可能的解

+1

另見'dynamic axis indexing':http://stackoverflow.com/questions/31094641/dynamic-axis-indexing-of-numpy-ndarray/31094758#31094758;一些'numpy'函數使用'transpose'('rollaxis'),另一些則構造一個索引元組。 – hpaulj

回答

1

可以採取與指定軸線的陣列的視圖移到前面使用numpy.rollaxis

def fun(a, axis, state): 
    return numpy.rollaxis(a, axis)[state] 

演示:

>>> a = numpy.arange(12).reshape([3, 2, 2]) 
>>> def fun(a, axis, state): 
...  return numpy.rollaxis(a, axis)[state] 
... 
>>> fun(a, 0, 1) 
array([[4, 5], 
     [6, 7]]) 
>>> fun(a, 1, 1) 
array([[ 2, 3], 
     [ 6, 7], 
     [10, 11]]) 

numpy.rollaxis還支持移動軸到另一個位置時,雖然它解釋爲什麼這是奇怪的論據。