2017-07-17 147 views
2

我想在Keras培訓GAN。我的最終目標是BEGAN,但我從最簡單的一個開始。瞭解如何正確凍結重量是必要的,這就是我所苦苦掙扎的。如何在Keras中編譯模型後動態凍結權重?

在生成器訓練時間期間,鑑別器權重可能不會更新。我想凍結解凍鑑別交替爲訓練發生器和鑑別器。問題是,設置可訓練的參數爲false,鑑別器模型或甚至在其'權重不停止模型來訓練(和權重更新)。另一方面,當我在將可訓練設置爲False之後編譯模型時,權重變爲unfreezable。我無法在每次迭代後編譯模型,因爲這否定了整個訓練的想法。

由於這個問題,似乎很多Keras的實現都被竊聽,或者他們的工作是因爲舊版本中的一些非直觀的技巧或某事。

回答

1

我幾個月前嘗試這個示例代碼和它的工作: https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py

這不是GAN中最簡單的形式,但據我記得,這不是太難以去除的分類損失將模型轉換成GAN。

您不需要打開/關閉鑑別器的可訓練屬性並重新編譯。只需創建並編譯兩個模型對象,一個帶有trainable=True(代碼中爲discriminator),另一個帶有trainable=False(代碼中爲combined)。

當您更新鑑別器時,請致電discriminator.train_on_batch()。當您更新發生器時,請致電combined.train_on_batch()

0

您是否可以使用tf.stop_gradient來有條件地凍結重量?

+0

'tf.stop_gradient'正在阻止漸變流動,這不是我想要實現的。我想爲梯度計算梯度流量和計算梯度,但不執行更新操作。 –

+0

然後,您可能會更好地將要更新的變量列表明確傳遞給tensorflow更新操作,而不是始終凍結/解凍加權。 –

+0

你說得對,但它是Tensorflow解決方案,Keras不允許這樣做。你有一個模型抽象,你主要有'fit'和'train_on_batch'方法,就這些。如果在純Keras中沒有解決方案,那麼我將切換到Tensorflow。 –

0

也許你的敵對網絡(發電機加辨別器)寫在'模型'中。然而,即使你設置了d.trainable = False,獨立d網絡也是不可訓練的,但整個對抗網絡中的d仍然可以訓練。

您可以在設置之後使用d_on_g.summary()d.trainable = False並且您會知道我的意思(注意可訓練變量)。