2015-11-05 82 views
0

我有一個圖形如下,其中輸入x有兩條路徑到達y。它們與使用cMulTable的gModule結合使用。現在,如果我做gModule:向後(x,y),我得到一個包含兩個值的表格。它們是否對應於從兩條路徑導出的誤差導數? enter image description here通過gModule向後移動手電筒

但由於path2包含其他nn層,我想我需要以逐步的方式推導出此路徑中的派生值。但爲什麼我得到dy/dx的兩個值表?

爲了讓事情更清晰,代碼測試,這是如下:

input1 = nn.Identity()() 
input2 = nn.Identity()() 
score = nn.CAddTable()({nn.Linear(3, 5)(input1),nn.Linear(3, 5)(input2)}) 
g = nn.gModule({input1, input2}, {score}) #gModule 

mlp = nn.Linear(3,3) #path2 layer 

x = torch.rand(3,3) 
x_p = mlp:forward(x) 
result = g:forward({x,x_p}) 
error = torch.rand(result:size()) 
gradient1 = g:backward(x, error) #this is a table of 2 tensors 
gradient2 = g:backward(x_p, error) #this is also a table of 2 tensors 

那麼,什麼是錯我的步驟是什麼?

P.S,也許我已經找到了原因,因爲g:落後({x,x_p},error)導致同一個表。所以我猜這兩個值分別代表dy/dx和dy/dx_p。

回答

1

我想你只是犯了一個錯誤構建你的gModule。每nn.ModulegradInput必須與其input具有完全相同的結構 - 這是backprop的工作方式。

下面是一個例子,如何使用nngraph創建像你這樣的模塊:

require 'torch' 
require 'nn' 
require 'nngraph' 

function CreateModule(input_size) 
    local input = nn.Identity()() -- network input 

    local nn_module_1 = nn.Linear(input_size, 100)(input) 
    local nn_module_2 = nn.Linear(100, input_size)(nn_module_1) 

    local output = nn.CMulTable()({input, nn_module_2}) 

    -- pack a graph into a convenient module with standard API (:forward(), :backward()) 
    return nn.gModule({input}, {output}) 
end 


input = torch.rand(30) 

my_module = CreateModule(input:size(1)) 

output = my_module:forward(input) 
criterion_err = torch.rand(output:size()) 

gradInput = my_module:backward(input, criterion_err) 
print(gradInput) 

UPDATE

正如我所說的,每nn.ModulegradInput必須具有完全相同的結構爲input。因此,如果您將模塊定義爲nn.gModule({input1, input2}, {score}),則您的gradOutput(反向傳遞的結果)將是一個漸變w.r.t表。 input1input2,你的情況是xx_p

唯一的問題是:爲什麼在地球上沒有出現錯誤時調用:

gradient1 = g:backward(x, error) 
gradient2 = g:backward(x_p, error) 

例外必須提高,因爲第一個參數必須是不是張量而是兩個張量的表。那麼,在計算:backward(input, gradOutput)時,大多數(也許是所有)火炬模塊都不會使用input參數(它們通常會存儲最後:forward(input)呼叫中的input副本)。事實上,這個論點是無用的,模塊甚至不打擾自己去驗證它。

+0

嗨亞歷克斯,謝謝你的回答。我沒有使用單個輸入x,而是使用兩個輸入a和b創建了gModule,而b的值取決於a。我這樣做是因爲nn層比線性轉換更復雜。它有一個LSTM結構。 –

+1

我還包括了我的代碼模擬,請查看@Alexander Lutsenko –