2017-06-02 84 views

回答

1

一般來說,我不會建議試圖擊敗NumPy的。很少有人可以競爭(對於長陣列),更難以找到更快的實現。即使速度更快,速度可能也不會超過2倍。所以它很少值得。

但是我最近試圖自己做這樣的事情,所以我可以分享一些有趣的結果。

我自己並沒有想到這件事。我基於我的方法numbas (re-)implementation of np.median他們可能知道他們在做什麼。

我最終什麼樣的主意是:

import numba as nb 
import numpy as np 

@nb.njit 
def _partition(A, low, high): 
    """copied from numba source code""" 
    mid = (low + high) >> 1 
    if A[mid] < A[low]: 
     A[low], A[mid] = A[mid], A[low] 
    if A[high] < A[mid]: 
     A[high], A[mid] = A[mid], A[high] 
     if A[mid] < A[low]: 
      A[low], A[mid] = A[mid], A[low] 
    pivot = A[mid] 

    A[high], A[mid] = A[mid], A[high] 

    i = low 
    for j in range(low, high): 
     if A[j] <= pivot: 
      A[i], A[j] = A[j], A[i] 
      i += 1 

    A[i], A[high] = A[high], A[i] 
    return i 

@nb.njit 
def _select_lowest(arry, k, low, high): 
    """copied from numba source code, slightly changed""" 
    i = _partition(arry, low, high) 
    while i != k: 
     if i < k: 
      low = i + 1 
      i = _partition(arry, low, high) 
     else: 
      high = i - 1 
      i = _partition(arry, low, high) 
    return arry[:k] 

@nb.njit 
def _nlowest_inner(temp_arry, n, idx): 
    """copied from numba source code, slightly changed""" 
    low = 0 
    high = n - 1 
    return _select_lowest(temp_arry, idx, low, high) 

@nb.njit 
def nlowest(a, idx): 
    """copied from numba source code, slightly changed""" 
    temp_arry = a.flatten() # does a copy! :) 
    n = temp_arry.shape[0] 
    return _nlowest_inner(temp_arry, n, idx) 

我做包含的定時之前的一些熱身電話。熱身是爲了讓編譯時間不包括在時序:

rselect(np.random.rand(10), 5) 
nlowest(np.random.rand(10), 5) 

有一個(非常)慢的電腦我改變了元件的數量和重複的比特數。但結果似乎表明,我(當然,在numba開發者所做的那樣)擊敗NumPy的:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'), 
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method') 
) 

rselect(np.random.rand(10), 5) 
nlowest(np.random.rand(10), 5) 

for i in results.index: 
    x = np.random.rand(i) 
    n = i // 2 
    for j in results.columns: 
     stmt = '{}(x, n)'.format(j) 
     setp = 'from __main__ import {}, x, n'.format(j) 
     results.set_value(i, j, timeit(stmt, setp, number=100)) 

print(results) 

Method nsmall_np nsmall_pd nsmall_pir  nlowest 
Size             
100  0.00343059 0.561372 0.00190855 0.000935566 
500  0.00428461 1.79398 0.00326862 0.00187225 
1000 0.00560669 3.36844 0.00432595 0.00364284 
5000  0.0132515 0.305471 0.0142569 0.0108995 
10000 0.0255161 0.340215 0.024847 0.0248285 
50000  0.105937 0.543337 0.150277  0.118294 
100000  0.2452 0.835571 0.333697  0.248473 
500000  1.75214 3.50201  2.20235  1.44085 

enter image description here

+0

你需要改變多少代碼才能使用'njit'? – piRSquared

+1

'_partition'函數被簡單地複製,'_select'函數只在最後一行('arry [:k]'而不是'arry [k]')中被改變。另外兩個函數被改變了一點:我改變了函數名稱,用一個新的'idx'參數替換了'mid'部分,並刪除了處理一個偶數長度數組中位數的部分。 'nlowest'函數最初是'median_impl'函數。我也用'@ njit'改變了'@ register_jitable',並且我不需要('想要')'@ overload'。說實話,這個評論可能需要花費更長時間才能改變numba源代碼。 :D – MSeifert

+0

是的,看着你鏈接的代碼,看起來他們已經是'numba'的老練用戶了。感謝分享:-) – piRSquared

2

更新
@ user2357112指出在我的功能在現場操縱的評論中。轉過來就是我的表現提升來自的地方。所以最後,我們與quickselectnumba的粗略實現具有非常相似的性能。仍然沒有什麼可以打噴嚏,但不是我所希望的。


正如我在質詢時說,我與numba瞎搞,想和大家分享我已經找到。

請注意,我已導入njit而不是jit。這是一個裝飾器,可以自動防止本身回退到本地python對象。意思是說,當它加快速度時,它只會使用它實際上可以加速的東西。這反過來意味着我的功能失敗了很多,而我找出什麼是允許的,什麼是不允許的。

到目前爲止,這是我的看法,與numba小號jitnjit寫東西是挑剔和困難,但那種值得的,當你看到一個不俗的表現回報。

這是我的快速和骯髒的quickselect功能

import numpy as np 
from numba import njit 
import pandas as pd 
import numexpr as ne 

@njit 
def rselect(a, k): 
    n = len(a) 
    if n <= 1: 
     return a 
    elif k > n: 
     return a 
    else: 
     p = np.random.randint(n) 
     pivot = a[p] 
     a[0], a[p] = a[p], a[0] 
     i = j = 1 
     while j < n: 
      if a[j] < pivot: 
       a[j], a[i] = a[i], a[j] 
       i += 1 
      j += 1 
     a[i-1], a[0] = a[0], a[i-1] 
     if i - 1 <= k <= i: 
      return a[:k] 
     elif k > i: 
      return np.concatenate((a[:i], rselect(a[i:], k - i))) 
     else: 
      return rselect(a[:i-1], k) 

你會發現它返回相同的元素以問題的方法。

rselect(x, 5) 

array([2, 1, 0, 3, 4]) 

什麼速度?

def nsmall_np(x, n): 
    return np.partition(x, n)[:n] 

def nsmall_pd(x, n): 
    pd.Series(x).nsmallest().values 

def nsmall_pir(x, n): 
    return rselect(x.copy(), n) 


from timeit import timeit 


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'), 
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method') 
) 

for i in results.index: 
    x = np.random.rand(i) 
    n = i // 2 
    for j in results.columns: 
     stmt = '{}(x, n)'.format(j) 
     setp = 'from __main__ import {}, x, n'.format(j) 
     results.set_value(
      i, j, timeit(stmt, setp, number=1000) 
     ) 

results 

Method nsmall_np nsmall_pd nsmall_pir 
Size          
100  0.003873 0.336693 0.002941 
1000  0.007683 1.170193 0.011460 
3000  0.016083 0.309765 0.029628 
6000  0.050026 0.346420 0.059591 
10000  0.106036 0.435710 0.092076 
100000 1.064301 2.073206 0.936986 
1000000 11.864195 27.447762 12.755983 

results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6)) 

[1]: https://i.stack.imgur.com/hKo2o png格式

+2

你似乎變異的輸入,而'numpy.partition'進行復印。你是否定時執行了'ndarray.partition'方法的性能? – user2357112

+0

@ user2357112好眼睛...看着它 – piRSquared

+0

@ user2357112和** PooF **有所有的性能好處。謝謝......看到亂搞已經教會了我一些東西。 – piRSquared

相關問題