2017-08-15 150 views
1

給定一個3維numpy數組,如何查找前n個最小值的索引?最小值的指數可以發現:查找三維numpy數組中的n個最小值的索引

i,j,k = np.where(my_array == my_array.min()) 
+2

的可能的複製[獲取N個最高值的索引在ndarray](https://stackoverflow.com/questions/26603747/get-the-indices-of-n-highest-values-in-an-ndarray) – stamaimer

+0

發佈的解決方案是否適合你? – Divakar

+0

是的,它確實有效。非常感謝夥伴 –

回答

4

下面是通用正變暗和通用n個最小號的一種方法 -

def smallestN_indices(a, N): 
    idx = a.ravel().argsort()[:N] 
    return np.stack(np.unravel_index(idx, a.shape)).T 

的的2D輸出數組將持有的每一行索引與最小陣列號之一對應的元組。

我們也可以使用argpartition,但可能無法維持訂單。所以,我們需要argsort還有多一點額外的工作 -

def smallestN_indices_argparitition(a, N, maintain_order=False): 
    idx = np.argpartition(a.ravel(),N)[:N] 
    if maintain_order: 
     idx = idx[a.ravel()[idx].argsort()] 
    return np.stack(np.unravel_index(idx, a.shape)).T 

採樣運行 -

In [141]: np.random.seed(1234) 
    ...: a = np.random.randint(111,999,(2,5,4,3)) 
    ...: 

In [142]: smallestN_indices(a, N=3) 
Out[142]: 
array([[0, 3, 2, 0], 
     [1, 2, 3, 0], 
     [1, 2, 2, 1]]) 

In [143]: smallestN_indices_argparitition(a, N=3) 
Out[143]: 
array([[1, 2, 3, 0], 
     [0, 3, 2, 0], 
     [1, 2, 2, 1]]) 

In [144]: smallestN_indices_argparitition(a, N=3, maintain_order=True) 
Out[144]: 
array([[0, 3, 2, 0], 
     [1, 2, 3, 0], 
     [1, 2, 2, 1]]) 

運行測試 -

In [145]: a = np.random.randint(111,999,(20,50,40,30)) 

In [146]: %timeit smallestN_indices(a, N=3) 
    ...: %timeit smallestN_indices_argparitition(a, N=3) 
    ...: %timeit smallestN_indices_argparitition(a, N=3, maintain_order=True) 
    ...: 
10 loops, best of 3: 97.6 ms per loop 
100 loops, best of 3: 8.32 ms per loop 
100 loops, best of 3: 8.34 ms per loop 
+0

工作過,謝謝! –

相關問題