2017-06-13 448 views
6

在numpy的,我可以做一個簡單的矩陣乘法這樣的:怎麼辦矩陣的點積在PyTorch

a = numpy.arange(2*3).reshape(3,2) 
b = numpy.arange(2).reshape(2,1) 
print(a) 
print(b) 
print(a.dot(b)) 

然而,當我想這跟PyTorch張量,這不起作用:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2) 
b = torch.Tensor([[2, 1]]).view(2, -1) 
print(a) 
print(a.size()) 

print(b) 
print(b.size()) 

print(torch.dot(a, b)) 

此代碼引發以下錯誤:

RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

任何想法如何簡單的點積可以P中進行yTorch?

回答

13

您正在尋找

torch.mm(a,b) 

注意torch.dot()表現不同來np.dot()。關於什麼是可取的here有一些討論。具體而言,torch.dot()ab作爲1D向量(不考慮它們的原始形狀)並計算它們的內積。錯誤被拋出,因爲這種行爲使得你的a長度爲6的矢量,而你的b長度爲2的矢量;因此不能計算其內積。對於PyTorch中的矩陣乘法,請使用torch.mm()。 Numpy的np.dot()相比之下更加靈活;它計算一維數組的內積併爲二維數組執行矩陣乘法。

5

大廈mexmex回答,如果你想要做一個矩陣乘法,你能做到這一點的方法有三種:

AB = A.mm(B) # computes A.B (matrix multiplication) 
# or 
AB = torch.mm(A, B) 
# or even simpler 
AB = A @ B # Python 3.5+ 

對於逐元素相乘,你可以簡單地做(如果A和B具有相同的形狀)

A * B # element-wise matrix multiplication (Hadamard product)