2016-11-28 94 views
3

在一個大的代碼庫,我使用np.broadcast_to廣播陣列(只是用簡單的例子在這裏):取消廣播numpy的陣列

In [1]: x = np.array([1,2,3]) 

In [2]: y = np.broadcast_to(x, (2,1,3)) 

In [3]: y.shape 
Out[3]: (2, 1, 3) 

其他地方在代碼中,我使用第三方功能,可以運行以Numpy數組的矢量化方式,但不是ufuncs。這些函數不理解廣播,這意味着在像y這樣的數組上調用這樣的函數效率不高。諸如Numpy的vectorize這樣的解決方案不好,因爲雖然他們理解廣播,但是他們在數組元素上引入了一個for循環,這是非常低效的。

理想情況下,我希望能夠做的是有一個函數,我們可以調用例如unbroadcast,它返回一個最小形狀的數組,如果需要可以將其廣播回全尺寸。因此,例如爲:

In [4]: z = unbroadcast(y) 

In [5]: z.shape 
Out[5]: (1, 1, 3) 

然後我就可以在z運行第三方功能,那麼廣播結果返回給y.shape

有沒有一種方法可以實現unbroadcast,它依賴於Numpy的公共API?如果沒有,是否有任何黑客可以產生期望的結果?

+0

「y [None,0]'? – Divakar

+0

你是什麼意思「最小的形狀」? 'unbroadcast'返回的N維的最小形狀是不是總是'(1,1,...,1)'(甚至是'(1,)')? –

+0

我的意思是最小的形狀,仍然包含所有需要的數據將其廣播回完整陣列。所以在上面的例子中,「z.shape」是「(1,1,3)」而不是「(1,1,1)」。 – astrofrog

回答

3

這大概相當於自己的解決方案,只有一點點更內置-在。它採用as_stridednumpy.lib.stride_tricks

import numpy as np 
from numpy.lib.stride_tricks import as_strided 

x = np.arange(16).reshape(2,1,8,1) # shape (2,1,8,1) 
y = np.broadcast_to(x,(2,3,8,5)) # shape (2,3,8,5) broadcast 

def unbroadcast(arr): 
    #determine unbroadcast shape 
    newshape = np.where(np.array(arr.strides) == 0,1,arr.shape) # [2,1,8,1], thanks to @Divakar 
    return as_strided(arr,shape=newshape) # strides are automatically set here 

z = unbroadcast(x) 
np.all(z==x) # is True 

注意,在我原來的答案,我並沒有定義一個函數,並將得到z陣列有(64,0,8,0)strides,而輸入具有(64,64,8,8)。在當前版本中,返回的z數組與x有相同的大小,我猜想傳遞並返回數組強制創建副本。無論如何,我們可以在任何情況下手動設置as_strided以獲得相同的數組,但這在上述設置中似乎不是必需的。

+1

或'np.where(np.array(y.strides)== 0,1,y.shape)'爲新聞形式? – Divakar

+0

@Divakar權利,我一直忘記'np.where' :)顯然這是numpy-idiomatic版本,謝謝。 –

+1

我的解決方案非常好的改進 - 謝謝! – astrofrog

4

我有一個可能的解決方案,所以會在這裏發佈(但是如果有人有更好的,請隨時回覆!)。一個解決方案是檢查strides說法陣列,這將是0一起廣播尺寸:

def unbroadcast(array): 
    slices = [] 
    for i in range(array.ndim): 
     if array.strides[i] == 0: 
      slices.append(slice(0, 1)) 
     else: 
      slices.append(slice(None)) 
    return array[slices] 

這給:

In [14]: unbroadcast(y).shape 
Out[14]: (1, 1, 3)