2016-12-30 91 views
4

考慮陣列x和delta變量d計數在numpy的陣列的元素數量是每一個其他元件

np.random.seed([3,1415]) 
x = np.random.randint(100, size=10) 
d = 10 

對於x各元素的增量內,我要計算有多少其他元件中的每個是在德爾塔內d距離遠。

所以X看起來像

print(x) 

[11 98 74 90 15 55 13 11 13 26] 

結果應該

[5 2 1 2 5 1 5 5 5 1] 

我已經試過
策略:

  • 使用廣播採取外差
  • 外差的絕對值
  • 金額多少不超過閾值

(np.abs(x[:, None] - x) <= d).sum(-1) 

[5 2 1 2 5 1 5 5 5 1] 

這個偉大的工程。但是,它不會縮放。那個外面的差別是O(n^2)時間。我怎樣才能得到與二次時間不成比例的相同解決方案?

回答

4

上市在這個職位是兩個基於searchsorted strategyOP's answer post的變種

def pir3(a,d): # Short & less efficient 
    sidx = a.argsort() 
    p1 = a.searchsorted(a+d,'right',sorter=sidx) 
    p2 = a.searchsorted(a-d,sorter=sidx) 
    return p1 - p2 

def pir4(a, d): # Long & more efficient 
    s = a.argsort() 

    y = np.empty(s.size,dtype=np.int64) 
    y[s] = np.arange(s.size) 

    a_ = a[s] 
    return (
     a_.searchsorted(a_ + d, 'right') 
     - a_.searchsorted(a_ - d) 
    )[y] 

的更有效的方法導出有效主意,this post得到s.argsort()

運行測試 -

In [155]: # Inputs 
    ...: a = np.random.randint(0,1000000,(10000)) 
    ...: d = 10 


In [156]: %timeit pir2(a,d) #@ piRSquared's post solution 
    ...: %timeit pir3(a,d) 
    ...: %timeit pir4(a,d) 
    ...: 
100 loops, best of 3: 2.43 ms per loop 
100 loops, best of 3: 4.44 ms per loop 
1000 loops, best of 3: 1.66 ms per loop 
+0

非常感謝。我已經更新了測試結果以包含這些變體並擴展了數組的大小。 – piRSquared

+0

@piRSquared這些是有道理的。對於較小的數組,'pir4'中用於創建獲取's.argsort()'的範圍數組的開銷使其不如簡單排序更有價值。對於您爲了計算這個計數問題而使用'searchsorted'好心想! – Divakar

1

戰略

  • 由於x不一定排序,我們將對其進行排序,並通過argsort跟蹤排序排列,所以我們可以扭轉排列。
  • 我們將使用np.searchsortedxx - d找到x的值開始超過x - d的起始位置。
  • 再做一次在另一邊,除了我們不得不使用np.searchsorted參數side='right'和使用x + d
  • 採取正確的區別和左searchsorts計算元素是每個元素的+/- d內的數
  • 使用argsort扭轉排序排列

限定方法中問題呈現爲pir1

def pir1(a, d): 
    return (np.abs(a[:, None] - a) <= d).sum(-1) 

我們將定義一個新的功能pir2

def pir2(a, d): 
    s = x.argsort() 
    a_ = a[s] 
    return (
     a_.searchsorted(a_ + d, 'right') 
     - a_.searchsorted(a_ - d) 
    )[s.argsort()] 

演示

pir1(x, d) 

[5 2 1 2 5 1 5 5 5 1]  

pir1(x, d) 

[5 2 1 2 5 1 5 5 5 1]  

時機
pir2是明顯的贏家!

代碼

功能

def pir1(a, d): 
    return (np.abs(a[:, None] - a) <= d).sum(-1) 

def pir2(a, d): 
    s = x.argsort() 
    a_ = a[s] 
    return (
     a_.searchsorted(a_ + d, 'right') 
     - a_.searchsorted(a_ - d) 
    )[s.argsort()] 

####################### 
# From Divakar's post # 
####################### 
def pir3(a,d): # Short & less efficient 
    sidx = a.argsort() 
    p1 = a.searchsorted(a+d,'right',sorter=sidx) 
    p2 = a.searchsorted(a-d,sorter=sidx) 
    return p1 - p2 

def pir4(a, d): # Long & more efficient 
    s = a.argsort() 

    y = np.empty(s.size,dtype=np.int64) 
    y[s] = np.arange(s.size) 

    a_ = a[s] 
    return (
     a_.searchsorted(a_ + d, 'right') 
     - a_.searchsorted(a_ - d) 
    )[y] 

測試

from timeit import timeit 

results = pd.DataFrame(
    index=np.arange(1, 50), 
    columns=['pir%s' %i for i in range(1, 5)]) 

for i in results.index: 
    np.random.seed([3,1415]) 
    x = np.random.randint(1000000, size=i) 
    for j in results.columns: 
     setup = 'from __main__ import x, {}'.format(j) 
     results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=10000) 

results.plot() 

enter image description here


延伸到了更大的陣列
擺脫pir1

from timeit import timeit 

results = pd.DataFrame(
    index=np.arange(1, 11) * 1000, 
    columns=['pir%s' %i for i in range(2, 5)]) 

for i in results.index: 
    np.random.seed([3,1415]) 
    x = np.random.randint(1000000, size=i) 
    for j in results.columns: 
     setup = 'from __main__ import x, {}'.format(j) 
     results.loc[i, j] = timeit('{}(x, 10)'.format(j), setup=setup, number=100) 

results.insert(0, 'pir1', 0) 

results.plot() 

enter image description here

相關問題