2013-02-19 56 views
3

我想弄清楚一個更好的方法來檢查兩個2D數組是否包含相同的行。以下面的案例爲例:最快的方法來檢查兩個陣列是否有相同的行

>>> a 
array([[0, 1, 2], 
     [3, 4, 5], 
     [6, 7, 8]]) 
>>> b 
array([[6, 7, 8], 
     [3, 4, 5], 
     [0, 1, 2]]) 

在這種情況下b=a[::-1]。檢查兩行是否相等:

>>>a=a[np.lexsort((a[:,0],a[:,1],a[:,2]))] 
>>>b=b[np.lexsort((b[:,0],b[:,1],b[:,2]))] 
>>> np.all(a-b==0) 
True 

這很好,而且速度相當快。但是問題來時,左右兩排「閉」:

array([[-1.57839867 2.355354 -1.4225235 ], 
     [-0.94728367 0.   -1.4225235 ], 
     [-1.57839867 -2.355354 -1.4225215 ]]) <---note ends in 215 not 235 
array([[-1.57839867 -2.355354 -1.4225225 ], 
     [-1.57839867 2.355354 -1.4225225 ], 
     [-0.94728367 0.   -1.4225225 ]]) 

在1E-5的容差這兩個數組是按行相等,但lexsort會告訴你,否則。這可以通過不同的排序順序來解決,但我想要一個更一般的情況。

我用的想法醞釀:

a=a.reshape(-1,1,3) 
>>> a-b 
array([[[-6, -6, -6], 
     [-3, -3, -3], 
     [ 0, 0, 0]], 

     [[-3, -3, -3], 
     [ 0, 0, 0], 
     [ 3, 3, 3]], 

     [[ 0, 0, 0], 
     [ 3, 3, 3], 
     [ 6, 6, 6]]]) 
>>> np.all(np.around(a-b,5)==0,axis=2) 
array([[False, False, True], 
     [False, True, False], 
     [ True, False, False]], dtype=bool) 
>>>np.all(np.any(np.all(np.around(a-b,5)==0,axis=2),axis=1)) 
True 

這不告訴你,如果陣列是按行等於只是如果b所有點都接近a值。行數可以是幾百,我需要做很多。有任何想法嗎?

+1

我會拋出'scipy.spatial.cKDTree'(可能KDTree,取決於scipy版本和用法),對於一種可能更直接的方法。 – seberg 2013-02-19 23:16:42

+0

這正是我正在尋找的。知道必須有更好的方法。 – Daniel 2013-02-19 23:54:42

回答

1

你最後的代碼不會做你認爲它在做什麼。它告訴你的是b中的每一行是否接近a中的一行。如果將axis更改爲用於np.anynp.all的外部調用,則可以檢查a中的每一行是否接近b中的某一行。如果b中的每一行都接近a中的一行,並且a中的每一行都接近b中的一行,則這些集合是相等的。大概不會計算非常有效的,但可能在numpy的爲中等大小的數組非常快:

def same_rows(a, b, tol=5) : 
    rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) 
    return (np.all(np.any(rows_close, axis=-1), axis=-1) and 
      np.all(np.any(rows_close, axis=0), axis=0)) 

>>> rows, cols = 5, 3 
>>> a = np.arange(rows * cols).reshape(rows, cols) 
>>> b = np.arange(rows) 
>>> np.random.shuffle(b) 
>>> b = a[b] 
>>> a 
array([[ 0, 1, 2], 
     [ 3, 4, 5], 
     [ 6, 7, 8], 
     [ 9, 10, 11], 
     [12, 13, 14]]) 
>>> b 
array([[ 9, 10, 11], 
     [ 3, 4, 5], 
     [ 0, 1, 2], 
     [ 6, 7, 8], 
     [12, 13, 14]]) 
>>> same_rows(a, b) 
True 
>>> b[0] = b[1] 
>>> b 
array([[ 3, 4, 5], 
     [ 3, 4, 5], 
     [ 0, 1, 2], 
     [ 6, 7, 8], 
     [12, 13, 14]]) 
>>> same_rows(a, b) # not all rows in a are close to a row in b 
False 

而對於沒有太大的陣列,性能是合理的,即使它是有打造(rows, rows, cols)數組:

In [2]: rows, cols = 1000, 10 

In [3]: a = np.arange(rows * cols).reshape(rows, cols) 

In [4]: b = np.arange(rows) 

In [5]: np.random.shuffle(b) 

In [6]: b = a[b] 

In [7]: %timeit same_rows(a, b) 
10 loops, best of 3: 103 ms per loop 
+0

我確實提到過,這是我發佈的代碼的一個問題。這基本上是我最終寫了幾個額外的參數。我在距離公式中加入了一個更好的概念,讓我們更好地瞭解一個點是多麼接近,並使用lexsort方法來大幅減少需要傳遞給這種類型的檢查的行數。如果明天沒有人提出更好的想法,我會檢查你的答案。 – Daniel 2013-02-19 22:35:36

相關問題