一般來說,我不會建議試圖擊敗NumPy的。很少有人可以競爭(對於長陣列),更難以找到更快的實現。即使速度更快,速度可能也不會超過2倍。所以它很少值得。
但是我最近試圖自己做這樣的事情,所以我可以分享一些有趣的結果。
我自己並沒有想到這件事。我基於我的方法numbas (re-)implementation of np.median
。 他們可能知道他們在做什麼。
我最終什麼樣的主意是:
import numba as nb
import numpy as np
@nb.njit
def _partition(A, low, high):
"""copied from numba source code"""
mid = (low + high) >> 1
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
if A[high] < A[mid]:
A[high], A[mid] = A[mid], A[high]
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
A[high], A[mid] = A[mid], A[high]
i = low
for j in range(low, high):
if A[j] <= pivot:
A[i], A[j] = A[j], A[i]
i += 1
A[i], A[high] = A[high], A[i]
return i
@nb.njit
def _select_lowest(arry, k, low, high):
"""copied from numba source code, slightly changed"""
i = _partition(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high)
else:
high = i - 1
i = _partition(arry, low, high)
return arry[:k]
@nb.njit
def _nlowest_inner(temp_arry, n, idx):
"""copied from numba source code, slightly changed"""
low = 0
high = n - 1
return _select_lowest(temp_arry, idx, low, high)
@nb.njit
def nlowest(a, idx):
"""copied from numba source code, slightly changed"""
temp_arry = a.flatten() # does a copy! :)
n = temp_arry.shape[0]
return _nlowest_inner(temp_arry, n, idx)
我做包含的定時之前的一些熱身電話。熱身是爲了讓編譯時間不包括在時序:
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
有一個(非常)慢的電腦我改變了元件的數量和重複的比特數。但結果似乎表明,我(當然,在numba開發者所做的那樣)擊敗NumPy的:
results = pd.DataFrame(
index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
for i in results.index:
x = np.random.rand(i)
n = i // 2
for j in results.columns:
stmt = '{}(x, n)'.format(j)
setp = 'from __main__ import {}, x, n'.format(j)
results.set_value(i, j, timeit(stmt, setp, number=100))
print(results)
Method nsmall_np nsmall_pd nsmall_pir nlowest
Size
100 0.00343059 0.561372 0.00190855 0.000935566
500 0.00428461 1.79398 0.00326862 0.00187225
1000 0.00560669 3.36844 0.00432595 0.00364284
5000 0.0132515 0.305471 0.0142569 0.0108995
10000 0.0255161 0.340215 0.024847 0.0248285
50000 0.105937 0.543337 0.150277 0.118294
100000 0.2452 0.835571 0.333697 0.248473
500000 1.75214 3.50201 2.20235 1.44085
你需要改變多少代碼才能使用'njit'? – piRSquared
'_partition'函數被簡單地複製,'_select'函數只在最後一行('arry [:k]'而不是'arry [k]')中被改變。另外兩個函數被改變了一點:我改變了函數名稱,用一個新的'idx'參數替換了'mid'部分,並刪除了處理一個偶數長度數組中位數的部分。 'nlowest'函數最初是'median_impl'函數。我也用'@ njit'改變了'@ register_jitable',並且我不需要('想要')'@ overload'。說實話,這個評論可能需要花費更長時間才能改變numba源代碼。 :D – MSeifert
是的,看着你鏈接的代碼,看起來他們已經是'numba'的老練用戶了。感謝分享:-) – piRSquared