2017-02-18 77 views
2

兩個矢量的點積可以通過numpy.dot來計算。現在我想計算矢量的陣列的點積:兩個矢量點積的可伸縮解決方案

>>> numpy.arange(15).reshape((5, 3)) 
array([[ 0, 1, 2], 
     [ 3, 4, 5], 
     [ 6, 7, 8], 
     [ 9, 10, 11], 
     [12, 13, 14]]) 

載體是行向量和輸出應該是包含從點積的結果一維陣列:

array([ 5, 50, 149, 302, 509]) 

對於交叉產品(numpy.cross),可以通過指定axis關鍵字輕鬆實現。但numpy.dot沒有這樣的選項,並傳遞它兩個2d數組將導致普通的矩陣產品。我也看了一下numpy.tensordot,但這似乎也沒有做這項工作(作爲一個擴展矩陣產品)。

我知道我可以經由

>>> numpy.einsum('ij, ji -> i', array2d, array2d.T) 

計算每用於2D陣列元件的點積然而這種解決方案並沒有對於1D陣列(即僅單個元素)工作。我想獲得一個適用於1d數組(返回標量)和1d數組(返回1d數組)的數組的解決方案。

回答

4

使用np.einsumellipsis(...)以考慮與可變維數,數組像這樣 -

np.einsum('...i,...i ->...', a, a) 

在其上陳述的文檔 -

要啓用和控制廣播,使用省略號。默認 NumPy風格的廣播是通過在 每一項的左側添加省略號來完成的,如np.einsum('... ii - > ... i',a)。沿着 第一個和最後一個軸線的軌跡,你可以做np.einsum('i ... i',a),或者做一個 矩陣矩陣乘積,而不是最右邊的指標, 你可以做np.einsum('ij ...,jk ...-> ik ...',a,b)。

樣品上2D1D陣列運行 -

In [88]: a 
Out[88]: 
array([[ 0, 1, 2], 
     [ 3, 4, 5], 
     [ 6, 7, 8], 
     [ 9, 10, 11], 
     [12, 13, 14]]) 

In [89]: np.einsum('...i,...i ->...', a, a) # On 2D array 
Out[89]: array([ 5, 50, 149, 302, 509]) 

In [90]: b = a[:,0] 

In [91]: b 
Out[91]: array([ 0, 3, 6, 9, 12]) 

In [92]: np.einsum('...i,...i ->...', b,b) # On 1D array 
Out[92]: 270 

運行測試 -

因爲,我們需要保持一個軸對齊,至少有2D陣列的np.einsumnp.matmul或一個最新的@ operator會很有效率。

In [95]: a = np.random.rand(1000,1000) 

# @unutbu's soln 
In [96]: %timeit (a*a).sum(axis=-1) 
100 loops, best of 3: 3.63 ms per loop 

In [97]: %timeit np.einsum('...i,...i ->...', a, a) 
1000 loops, best of 3: 944 µs per loop 

In [98]: a = np.random.rand(1000) 

# @unutbu's soln 
In [99]: %timeit (a*a).sum(axis=-1) 
100000 loops, best of 3: 9.11 µs per loop 

In [100]: %timeit np.einsum('...i,...i ->...', a, a) 
100000 loops, best of 3: 5.59 µs per loop 
+0

很好的解決方法,謝謝!我還發現可以使用'numpy.linalg.norm(array,axis = -1)** 2',但是性能類似於'(a * a).sum(axis = -1)'(即相對於'einsum'),我發現這很令人驚訝,因爲我沒有看到'einsum'的調用與'linalg.norm'函數不同(除了平方,但我驗證了這個貢獻可以忽略不計)。你有一個想法,爲什麼這兩個功能有這樣不同的表現? (除非'linalg.norm'的計算結果爲'sqrt((a * a).sum(axis = -1))',但是什麼是'einsum'有什麼不同?) –

+0

@a_guest嗯,我還沒有深入' einsum的實現,但基於其性能,我認爲它是一個C編譯的代碼。現在,用'numpy.linalg。規範「的基礎上,我認爲元素乘法,然後求和是這樣的編譯代碼的兩個階段,並與'numpy.linalg.norm'頂部,我們需要sqrt,然後你需要做平方當使用'einsum'將其與單個編譯代碼進行比較時,找回所需的輸出只是太多工作。再次聲明:不是與這些相關的編譯細節方面的專家。 – Divakar

相關問題