2016-07-29 95 views
0

我想找到兩組數組之間的最短距離。 x數組是相同的,只是包含整數。以下是我正在嘗試執行的一個示例:查找兩個數組之間最短距離的有效方法?

import numpy as np 
x1 = x2 = np.linspace(-1000, 1000, 2001) 
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1) 
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10) 

def dis(x1, y1, x2, y2): 
    return sqrt((y2-y1)**2+(x2-x1)**2) 

min_distance = np.inf 
for a, b in zip(x1, y1): 
    for c, d in zip(x2, y2): 
     if dis(a, b, c, d) < min_distance: 
      min_distance = dis(a, b, c, d) 

>>> min_distance 
2.2360679774997898 

此解決方案有效,但問題在於運行時。如果x的長度大約爲10,000,那麼解決方案就不可行,因爲程序運行時間爲O(n^2)。現在,我嘗試了一些近似值來加快程序的運行速度:

for a, b in zip(x1, y1): 
    cut = (x2 > a-20)*(x2 < a+20) 
    for c, d in zip(x2, y2): 
     if dis(a, b, c, d) < min_distance: 
      min_distance = dis(a, b, c, d) 

但是程序仍然比我想要的要長。現在,從我的理解來看,循環遍歷數組通常是低效的,所以我相信還有改進的餘地。關於如何加快此計劃的任何想法?

+0

可能重複的[如何用numpy計算歐幾里德距離?](http://stackoverflow.com/questions/1401712/how-can-the-euclidean-distance-be-calculated-with-numpy) – tmthydvnprt

+0

[有效檢查Python中大量對象的歐幾里得距離]的可能副本(http://stackoverflow.com/q/29885212/1461210) –

回答

0

這是一個難題,如果您願意接受近似值,這可能會有所幫助。我會檢查一下像Spottify的annoy

1

你的問題也可以表示爲2d碰撞檢測,所以quadtree可能會有所幫助。插入和查詢都在O(log n)時間內運行,因此整個搜索將在O(n log n)中運行。

還有一個建議,由於sqrt是單調的,因此可以比較距離的平方而不是距離本身,這將爲您節省n^2平方根計算。

1

scipy具有cdist function其計算所有點對間的距離:

from scipy.spatial.distance import cdist 
import numpy as np 

x1 = x2 = np.linspace(-1000, 1000, 2001) 
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1) 
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10) 

R1 = np.vstack((x1,y1)).T 
R2 = np.vstack((x2,y2)).T 

dists = cdist(R1,R2) # find all mutual distances 

print (dists.min()) 
# output: 2.2360679774997898 

這將運行比原來的for循環快250倍以上。

相關問題