2016-11-12 61 views
1

我有一個大小(n, m)np.uint8(所以它只包含[0, 255]中的值)大小X numpy陣列。我也有一個映射f[0, 255][0, 3]有效填充張量在numpy

我想創建一個形狀爲(4, n, m)的數組Y,使得y_{k, i, j} = 1 if k == f(x_{i, j})和0,否則。現在,我這樣做:

Y = np.zeros((4, n, m)) 
for i in range(256): 
    Y[f(i), X == i] = 1 

但這是超級慢,我不能找到一個更有效的方式來做到這一點。有任何想法嗎?

+0

你能分享一下func'f'的實現嗎? – Divakar

+0

嗯,分享是不容易的,因爲它是一個非常具體的問題函數,但是它是一個不需要時間執行的函數,並且它沒有任何東西可以用來使整個過程更快。它可能是這樣的: def f(x):v [x] 其中v = np.random.randint(4,size =(256,)) – dhokas

回答

1

假設f可以在所有迭代值進行操作一走,你可以使用broadcasting -

Yout = (f(X) == np.arange(4)[:,None,None]).astype(int) 

運行測試和驗證 -

In [35]: def original_app(X,n,m): 
    ...:  Y = np.zeros((4, n, m)) 
    ...:  for i in range(256): 
    ...:   Y[f(i), X == i] = 1 
    ...:  return Y 
    ...: 

In [36]: # Setup Inputs 
    ...: n,m = 2000,2000 
    ...: X = np.random.randint(0,255,(n,m)).astype('uint8') 
    ...: v = np.random.randint(4, size=(256,)) 
    ...: def f(x): 
    ...:  return v[x] 
    ...: 

In [37]: Y = original_app(X,n,m) 
    ...: Yout = (f(X) == np.arange(4)[:,None,None]).astype(int) 
    ...: 

In [38]: np.allclose(Yout,Y) # Verify 
Out[38]: True 

In [39]: %timeit original_app(X,n,m) 
1 loops, best of 3: 3.77 s per loop 

In [40]: %timeit (f(X) == np.arange(4)[:,None,None]).astype(int) 
10 loops, best of 3: 74.5 ms per loop 
1

標索引和布爾值的組合出現傷害你的速度:

In [706]: %%timeit 
    ...: Y=np.zeros((4,3,4)) 
    ...: for i in range(256): 
    ...: Y[f(i), X==i]+=1 
    ...: 

100 loops, best of 3: 12.5 ms per loop 

In [722]: %%timeit 
    ...: Y=np.zeros((4,3,4)) 
    ...: for i in range(256): 
    ...:  I,J=np.where(X==i) 
    ...:  Y[f(i),I,J] = 1 
    ...: 
100 loops, best of 3: 8.55 ms per loop 

這是爲

X=np.arange(12,dtype=np.uint8).reshape(3,4) 
def f(i): 
    return i%4 

在這種情況下,f(i)不是主要的時間消費:

In [718]: timeit K=[f(i) for i in range(256)] 
10000 loops, best of 3: 120 µs per loop 

,但得到的X==i指標是緩慢

In [720]: timeit K=[X==i for i in range(256)] 
1000 loops, best of 3: 1.29 ms per loop 
In [721]: timeit K=[np.where(X==i) for i in range(256)] 
100 loops, best of 3: 2.73 ms per loop 

我們需要重新思考的的X==i部分映射,而不是f(i)部分。

=====================

壓扁的最後2個維度幫助;

In [780]: %%timeit 
    ...: X1=X.ravel() 
    ...: Y=np.zeros((4,12)) 
    ...: for i in range(256): 
    ...:  Y[f(i),X1==i]=1 
    ...: Y.shape=(4,3,4) 
    ...: 
100 loops, best of 3: 3.16 ms per loop 
+0

令我驚訝的是 - 布爾索引不會轉換爲非零「無論如何,在罩下?你的第二個例子已經在我的電腦上下了一個'+ =' – Eric

+0

測試 - 你在第一部分中觀察到的大部分時間差異都是由於'='vs'+ ='錯字,可惜 – Eric

+0

你是對的。我一直在測試'+ =',因爲我正在考慮需要無緩衝的'add.at'的情況('X'中的重複值)。 – hpaulj