2017-09-04 47 views
1

我用Normalization圖層訓練了一個模型。該代碼是這樣的:如何預測何時在Keras中使用規範化或丟失層進行訓練?

在訓練階段:

model=Sequential() 
model.add() 

... 

k.set_learning_phase(1) 
ModelCheckpoint(weights_file) 
model.fit() 

在推理時間:

k.set_learning_phase(0) 
model.load_weights(weights_file) 
model.predict_classes() 

... 

Keras的版本:2.0.8。這是正確的,還是需要一些特殊的代碼來計算經過訓練後的BN,例如在Caffe中使用SegNet?

回答

2

不,您在使用BatchNormalization或Dropout圖層時不需要做任何特別的事情。凱拉斯已經跟蹤學習/測試階段,所以當使用predictpredict_classes時,它是正確的。

你甚至不需要手動設置學習階段,凱拉斯已經做到了。

+0

感謝您的及時回覆。我知道了。 – spider

相關問題