2015-01-16 38 views
3

我有一些物理模擬代碼,用python編寫並使用numpy/scipy。分析代碼顯示38%的CPU時間花費在一個雙重嵌套的循環中 - 這看起來過多,所以我一直試圖減少它。在numpy中創建索引數組 - 消除循環的雙重

該循環的目標是創建一個索引數組,顯示2D數組的元素與一維數組的哪些元素相等。

indices[i,j] = where(1D_array == 2D_array[i,j]) 

舉個例子,如果1D_array = [7.2, 2.5, 3.9]

2D_array = [[7.2, 2.5] 
      [3.9, 7.2]] 

我們應該有

indices = [[0, 1] 
      [2, 0]] 

我現在有這個實現

for i in range(ni): 
    for j in range(nj): 
     out[i, j] = (1D_array - 2D_array[i, j]).argmin() 

的因爲我正在處理浮點數,所以需要,所以平等不一定是確切的。我知道1D數組中的每個數字都是唯一的,並且2D數組中的每個元素都有匹配,所以這種方法可以給出正確的結果。

有沒有什麼方法可以消除雙循環?

我需要的索引陣列以執行下列操作:

f = complex_function(1D_array) 
output = f[indices] 

這比替代速度更快,因爲2D陣列的大小爲N×N的具有1×N個爲比較1D陣列和2D陣列具有許多重複值。如果任何人都可以提出在相同的輸出到達,而無需通過索引陣列中會以不同的方式,這也可能是一個解決辦法

+0

'1D_array'總是排序嗎? –

+0

@AshwiniChaudhary,不,不是。事實上,它永遠不會。我將編輯該示例以刪除該示例。 – Sten

+0

爲了達到這個目的,我認爲1D_array中的條目不會重複。爲什麼不用1D_array創建一個字典,其中的值作爲鍵和索引作爲值?即'{0:7.2,1:2.5,2:3.9}'那麼你只需要將字典應用到數組。 – Roberto

回答

1

的字典法,有些人有建議可能會起作用,但它要求您事先知道目標數組(2d數組)中的每個元素在搜索數組(您的1d數組)中都有完全匹配。即使原則上這應該是真的,但仍然需要處理浮點精度問題,例如試試.1 * 3 == .3

另一種方法是使用numpy的searchsorted函數。 searchsorted需要一個排序的1d搜索數組,然後任何traget數組然後爲目標數組中的每個項目在搜索數組中找到最接近的元素。我已經根據你的情況調整了這個answer,看看它是如何描述find_closest函數的工作原理。

import numpy as np 

def find_closest(A, target): 
    order = A.argsort() 
    A = A[order] 

    idx = A.searchsorted(target) 
    idx = np.clip(idx, 1, len(A)-1) 
    left = A[idx-1] 
    right = A[idx] 
    idx -= target - left < right - target 
    return order[idx] 

array1d = np.array([7.2, 2.5, 3.9]) 
array2d = np.array([[7.2, 2.5], 
        [3.9, 7.2]]) 

indices = find_closest(array1d, array2d) 
print(indices) 
# [[0 1] 
# [2 0]] 
2

在純Python中,你可以在O(N)時間做使用字典這個唯一一次點球是怎麼回事參與Python的循環:

>>> arr1 = np.array([7.2, 2.5, 3.9]) 
>>> arr2 = np.array([[7.2, 2.5], [3.9, 7.2]]) 
>>> indices = dict(np.hstack((arr1[:, None], np.arange(3)[:, None]))) 
>>> np.fromiter((indices[item] for item in arr2.ravel()), dtype=arr2.dtype).reshape(arr2.shape) 
array([[ 0., 1.], 
     [ 2., 0.]]) 
1

爲了擺脫兩個Python for循環,就可以通過添加新的座標軸的陣列做所有的「一次」平等的比較(使它們與每個broadcastable其他)。

請記住,這會產生一個包含len(arr1)*len(arr2)值的新數組。如果這是一個非常大的數字,這種方法可能是不可行的,這取決於你的記憶力的限制。否則,它應該是相當快:

>>> (arr1[:,np.newaxis] == arr2[:,np.newaxis]).argmax(axis=1) 
array([[0, 1], 
     [2, 0]], dtype=int32) 

如果你需要得到最接近匹配值的指數arr1而是使用:

np.abs(arr1[:,np.newaxis] - arr2[:,np.newaxis]).argmin(axis=1)