2015-11-01 35 views

Numba: cell vars are not supported

I want to use numba to speed up this function:

from numba import jit 
@jit 
def rownowaga_pypy(u, v):
    wymiar_x = len(u) 
    wymiar_y = len(u[1]) 
    f = [[[0 for j in range(wymiar_y)] for i in range(wymiar_x)] for k in range(9)] 
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.] 
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.] 
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36] 
    for i in range(wymiar_x):
        for j in range(wymiar_y):
            for k in range(9):
                up = u[i][j]
                vp = v[i][j]
                udot = (up**2 + vp**2)
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f 

I test it with data like this:

import timeit 
import math as m 

u = [[m.sin(i) + m.cos(j) for j in range(40)] for i in range(1000)] 
y = [[m.sin(i) + m.cos(j) for j in range(40)] for i in range(1000)] 

t0 = timeit.default_timer() 

for i in range (10): 
    f = rownowaga_pypy(u,y) 

dt = timeit.default_timer() - t0 
print('loop time:', dt) 

and I get this error:

Traceback (most recent call last): 
    File "C:\Users\Ricevind\Desktop\PyPy\Skrypty\Rownowaga.py", line 29, in <module> 
    f = rownowaga_pypy(u,y) 
    File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 171, in _compile_for_args 
    return self.compile(sig) 
    File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 348, in compile 
    flags=flags, locals=self.locals) 
    File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 637, in compile_extra 
    return pipeline.compile_extra(func) 
    File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 356, in compile_extra 
    raise e 
    File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 351, in compile_extra 
    bc = self.extract_bytecode(func) 
    File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 343, in extract_bytecode 
    bc = bytecode.ByteCode(func=self.func) 
    File "C:\pyzo2014a\lib\site-packages\numba\bytecode.py", line 343, in __init__ 
    raise NotImplementedError("cell vars are not supported") 
NotImplementedError: cell vars are not supported 

I am mostly interested in what "cell vars are not supported" means, since Google returns no meaningful results for it.

Answer


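Numba currently does not work well with nested lists of lists (at least as of v0.21). I believe that is what the "cell vars" error refers to, but I'm not 100% sure. Below, I convert everything to NumPy arrays so that numba can optimize the code: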

import numpy as np 
import numba as nb 
import math 

def rownowaga(u, v): 
    wymiar_x = len(u) 
    wymiar_y = len(u[1]) 
    f = [[[0 for j in range(wymiar_y)] for i in range(wymiar_x)] for k in range(9)] 
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.] 
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.] 
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36] 
    for i in range(wymiar_x):
        for j in range(wymiar_y):
            for k in range(9):
                up = u[i][j]
                vp = v[i][j]
                udot = (up**2 + vp**2)
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f 

# Pull these out so that numba treats them as constant arrays 
cx = np.array([0., 1., 0., -1., 0., 1., -1., -1., 1.]) 
cy = np.array([0., 0., 1., 0., -1., 1., 1., -1., -1.]) 
w = np.array([4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]) 

@nb.jit(nopython=True)
def rownowaga_numba(u, v):
    wymiar_x = u.shape[0]
    wymiar_y = u.shape[1]
    f = np.zeros((9, wymiar_x, wymiar_y))

    # range works on both Python 2 and 3; xrange exists only on Python 2
    for i in range(wymiar_x):
        for j in range(wymiar_y):
            for k in range(9):
                up = u[i, j]
                vp = v[i, j]
                udot = (up*up + vp*vp)
                cu = up*cx[k] + vp*cy[k]
                f[k, i, j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f
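
On the error message itself, here is one plausible reading that the thread does not confirm. In CPython, a "cell variable" is a local that a nested scope also uses. On Python 3 before 3.12, a list comprehension is compiled as a nested function, so the comprehension that builds f and reads wymiar_x and wymiar_y would turn those two locals into cell variables. The traceback is consistent with numba rejecting the function while reading its bytecode. The sketch below uses two hypothetical helper functions (not from the question) and only the standard `__code__.co_cellvars` attribute to show which locals become cells:

```python
def row_with_closure(width):
    # `width` is read by the nested function below, so CPython keeps it in a cell.
    def make_row():
        return [0] * width

    return make_row()


def row_without_closure(width):
    # No nested scope touches `width`; it stays an ordinary local variable.
    return [0] * width


print(row_with_closure.__code__.co_cellvars)     # ('width',)
print(row_without_closure.__code__.co_cellvars)  # ()
```

If this reading is right, allocating f without a comprehension that reads locals (for example with np.zeros) avoids cell variables altogether.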

Now let's set up some test arrays:

u = [[math.sin(i) + math.cos(j) for j in range(40)] for i in range(1000)] 
y = [[math.sin(i) + math.cos(j) for j in range(40)] for i in range(1000)] 

u_np = np.array(u) 
y_np = np.array(y) 

First, let's verify that my numba code gives the same answer as the OP's code:

f1 = rownowaga(u, y) 
f2 = rownowaga_numba(u_np, y_np) 

From the IPython notebook:

In [13]: np.allclose(f2, np.array(f1)) 
Out[13]: True
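
As an extra cross-check that is not part of the original answer, the same formula can be written with NumPy broadcasting instead of explicit loops. This is a different technique from the numba version, and `rownowaga_vectorized` and the upper-case constant names are made up for this sketch. It assumes u and v are 2-D arrays of the same shape:

```python
import numpy as np

# Same direction vectors and weights as in the code above.
CX = np.array([0., 1., 0., -1., 0., 1., -1., -1., 1.])
CY = np.array([0., 0., 1., 0., -1., 1., 1., -1., -1.])
W = np.array([4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36])


def rownowaga_vectorized(u, v):
    """Hypothetical broadcasting variant; returns an array of shape (9, nx, ny)."""
    u = np.asarray(u, dtype=np.float64)
    v = np.asarray(v, dtype=np.float64)
    # (9, 1, 1) combined with (nx, ny) broadcasts to (9, nx, ny).
    cu = CX[:, None, None] * u + CY[:, None, None] * v
    udot = u * u + v * v
    wk = W[:, None, None]
    return wk + wk * (3.0 * cu + 4.5 * cu**2 - 1.5 * udot)
```

Comparing its output to the other two with np.allclose gives a third, independent implementation to check against. I have not timed it here.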

Now let's time things on my laptop:

In [15]: %timeit f1 = rownowaga(u, y)
1 loops, best of 3: 288 ms per loop

In [16]: %timeit f2 = rownowaga_numba(u_np, y_np)
1000 loops, best of 3: 973 µs per loop
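
Outside IPython, a similar measurement can be sketched with the standard timeit module. `best_time` is a hypothetical helper, not something from numba or the original post. It makes one untimed warm-up call first, on the assumption that a lazily compiled @jit function does its compilation on the first call and that this cost should not be counted:

```python
import timeit


def best_time(func, *args, repeat=3, number=10):
    """Return the best per-call time in seconds, after one untimed warm-up call."""
    func(*args)  # warm-up; for a jitted function, compilation would happen here
    totals = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    return min(totals) / number


# Example with a plain built-in; pass the functions and arrays from above instead.
print(best_time(sum, list(range(1000))))
```

Taking the minimum over the repeats mirrors the "best of 3" that %timeit reports above.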

So we get a nice speedup of roughly 300x with minimal code changes. Note that I'm using a development build of Numba from shortly before 0.22:

In [16]: nb.__version__ 
Out[16]: '0.21.0+137.gac9929d'

This is amazing! :) Thank you – Ricevind