2017-04-05 97 views
2

我正在嘗試在keras中使用theano後端實現梯度範數的正則化術語improved WGAN training。基本上我想基於它是有多遠從1由於自定義丟失函數,Keras拋出DisconnectedInputError

我實現這樣一個自定義的損失,懲罰梯度的L2範數:

def get_gradient_norm(model, y_pred): 
    weights = model.trainable_weights 
    gradients = model.optimizer.get_gradients(K.mean(y_pred), weights) 
    acc = None 
    for g in gradients: 
     s = K.sum(K.square(g)) 
     if acc == None: 
      acc = s 
     else: 
      acc = s + acc 
    return K.sqrt(acc) 

def make_w_reg_loss(model): 
    lvar = K.variable(lamb, name="Lambda") 

    def foo(y_true, y_pred): 
     gnorm = get_gradient_norm(model, y_pred) 
     return lvar * K.square(gnorm - 1) 

return foo 

[...] 

critic.compile(loss=make_w_reg_loss(critic), optimizer=RMSprop(learn_rate)) 

它拋出一個DisconnectedInputError一次訓練過程中嘗試嘗試獲取我自定義丟失函數的漸變。

爲什麼?

用一些標準損失工作替換損失。這個錯誤是關於我定義的損失函數的。

看到這個要點我嘗試a minimal not-working example

編輯:

所以我想我知道如何使它現在的工作。 首先,我只是隨機添加這個詞來我的損失,直接從富返回之前(y_true,y_pred):

K.mean(y_pred) - K.mean(y_pred) 

顯然是一個常量零,如果我只能用這個詞作爲我的損失怎麼辦得到零。 但是,如果我將這個「常量零」添加到我的正則化損失中,它突然正常工作。我從正規化中獲得了非零的損失,並且許多train_on_batch的優化確實也減少了損失。

所以這是一個奇怪的問題,theano有點過分投擲異常?我的問題仍然存在:爲什麼它會拋出原始代碼。由於添加一個固定的零項修復它,它對我來說看起來像一個錯誤?

回答

0

我真的很想在keras中實現這個改進的wgan,我很驚訝地看到你是如何解決你的「問題」的。您是否驗證過您的wgan-gp損失是否按預期運行的實驗性實驗? 它應該很容易檢查,它是一個非常穩定的訓練,使您可以使用非常深的鑑別器;) 我想做你做的同樣的工作,但是與tensorflow後端,我會嘗試看看你的代碼和代碼在這裏:keras improved wgan

我會很高興聽到您的更新,我會再次寫在這裏,只要我有一個wgan-gp工作代碼在keras/tensorflow! P.S.上面的鏈接正在執行張量流代碼中的所有過程,迫使使用tf訓練功能。我真的很喜歡你的方法,在那裏我們可以簡單地定義一個keras損失,使用我們所有通常的keras高級API進行訓練;)

編輯:從你的代碼看,你完全可以用K後端,所以你的代碼應該輕鬆運行tensorflow後端。您是否嘗試更改後端以檢查問題/錯誤是否與Theano真正相關?

第二編輯:您正在計算梯度w.r.t的權重,但在wgan-gp紙張中,梯度損失是從梯度w.r.t開始計算生成樣本和實際樣本之間的平均樣本。這會帶來非常不同的結果。 在下面的鏈接,你可以找到一個非常好的改進wgan損失的實施,對theano是可能工作過: https://github.com/farizrahman4u/keras-contrib/

+0

我張貼的代碼是一個殘酷的削減版本,肯定不能正確實現什麼,這是隻是爲了展示這個問題。 我的真實代碼通過在真實樣本和假樣本之間傳遞插值數據點來實現採樣。目前爲止我只測試了玩具的例子,但他們看起來很有希望。但是,更多的「真實」工作讓我望而卻步,所以我無法測試出更復雜的數據集。 –

+0

我沒有測試tensorflow,沒有在這裏安裝它,因爲最終損失函數包含更多的術語,異常問題不是真正的問題。它只是讓我困惑。 我猜你發佈的wgan實現可能是由具有更多keras經驗的人編寫的,並且有更好的文檔記錄。當我回過頭來看時,我可能會用到那個,因爲它似乎是在GPU上實現插值部分,我是用CPU來完成的。涼! –

+0

我最終在調試中失去了12個小時,試圖修改我鏈接的代碼,以便作爲單獨的梯度損失損失工作(而不是集成到鑑別器中),並且我很快卡在牆上「tensorflow獲得無損」類型的錯誤。我突然記起你的修復,而且,我也在修復。沒有你的修復,如果我通過model.summary()可視化模型,沒有輸入層。用你簡單的修復,突然輸入層顯示爲輸入(並且梯度損失損失不起作用) –

相關問題