我是Numba的新手,並且正在努力加速一些已經證明對於numpy來說過於笨拙的計算。我在下面給出的例子比較了一個包含我的計算子集的函數,該函數使用向量化/ numpy和函數的numba版本,後者也通過評論@autojit裝飾器測試爲純python。我的numba代碼可以比numpy快嗎
我發現numba和numpy版本相對於純python提供了類似的加速,這兩者都是大約10倍速度改進的因素。 numpy版本實際上比我的numba函數稍快,但由於這種計算的4D性質,當numpy函數中的數組的大小比這個玩具示例大得多時,我會很快耗盡內存。
這個加速是好的,但我經常看到從純python移動到numba時,web上的加速度超過100x。
我想知道在nopython模式下移動到numba時是否會有普遍的預期速度增加。我還想知道是否有任何我的numba化功能的組件會限制進一步的速度增加。
import numpy as np
from timeit import default_timer as timer
from numba import autojit
import math
def vecRadCalcs(slope, skyz, solz, skya, sola):
nloc = len(slope)
ntime = len(solz)
[lenz, lena] = skyz.shape
asolz = np.tile(np.reshape(solz,[ntime,1,1,1]),[1,nloc,lenz,lena])
asola = np.tile(np.reshape(sola,[ntime,1,1,1]),[1,nloc,lenz,lena])
askyz = np.tile(np.reshape(skyz,[1,1,lenz,lena]),[ntime,nloc,1,1])
askya = np.tile(np.reshape(skya,[1,1,lenz,lena]),[ntime,nloc,1,1])
phi1 = np.cos(asolz)*np.cos(askyz)
phi2 = np.sin(asolz)*np.sin(askyz)*np.cos(askya- asola)
phi12 = phi1 + phi2
phi12[phi12> 1.0] = 1.0
phi = np.arccos(phi12)
return(phi)
@autojit
def RadCalcs(slope, skyz, solz, skya, sola, phi):
nloc = len(slope)
ntime = len(solz)
pop = 0.0
[lenz, lena] = skyz.shape
for iiT in range(ntime):
asolz = solz[iiT]
asola = sola[iiT]
for iL in range(nloc):
for iz in range(lenz):
for ia in range(lena):
askyz = skyz[iz,ia]
askya = skya[iz,ia]
phi1 = math.cos(asolz)*math.cos(askyz)
phi2 = math.sin(asolz)*math.sin(askyz)*math.cos(askya- asola)
phi12 = phi1 + phi2
if phi12 > 1.0:
phi12 = 1.0
phi[iz,ia] = math.acos(phi12)
pop = pop + 1
return(pop)
zenith_cells = 90
azim_cells = 360
nloc = 10 # nominallly ~ 700
ntim = 10 # nominallly ~ 200000
slope = np.random.rand(nloc) * 10.0
solz = np.random.rand(ntim) *np.pi/2.0
sola = np.random.rand(ntim) * 1.0*np.pi
base = np.ones([zenith_cells,azim_cells])
skya = np.deg2rad(np.cumsum(base,axis=1))
skyz = np.deg2rad(np.cumsum(base,axis=0)*90/zenith_cells)
phi = np.zeros(skyz.shape)
start = timer()
outcalc = RadCalcs(slope, skyz, solz, skya, sola, phi)
stop = timer()
outcalc2 = vecRadCalcs(slope, skyz, solz, skya, sola)
stopvec = timer()
print(outcalc)
print(stop-start)
print(stopvec-stop)