2016-09-06 1803 views
1

我想在火車組(is_training=True)和驗證集(is_training=False)上運行給定模型,具體說明如何應用dropout。現在,prebuilt models公開了一個參數is_training,在構建網絡時,該參數傳遞給dropout層。問題是,如果我使用不同的值is_training兩次調用方法,我會得到兩個不分享權重的網絡(我認爲?)。我怎樣才能讓兩個網絡共享相同的權重,以便我可以運行我在驗證集上訓練過的網絡?帶有is_training True和False的Tensorflow(tf-slim)模型

+0

我覺得默認的行爲是共享兩種情況之間的權重,所以你不要有什麼關係。 'tf-slim'使用'tf.get_variable()'來重用調用之間的變量。 –

+0

好的,我認爲這主要是有效的。你需要確保'scope'被設置,然後爲了安全,最好也設置'reuse = True'。 –

回答

1

我寫了一個解決方案,您的評論在列車和測試模式中使用Overfeat。 (我無法測試,所以你可以檢查是否正常工作?)

一是部分進口及參數:

import tensorflow as tf 
slim = tf.contrib.slim 
overfeat = tf.contrib.slim.nets.overfeat 

batch_size = 32 
inputs = tf.placeholder(tf.float32, [batch_size, 231, 231, 3]) 
dropout_keep_prob = 0.5 
num_classes = 1000 

在訓練模式,我們通過正常範圍到功能overfeat

scope = 'overfeat' 
is_training = True 

output = overfeat.overfeat(inputs, num_classes, is_training,   
          dropout_keep_prob, scope=scope) 

然後在測試模式下,我們創建了與reuse=True相同的範圍。

scope = tf.VariableScope(reuse=True, name='overfeat') 
is_training = False 

output = overfeat.overfeat(inputs, num_classes, is_training,   
          dropout_keep_prob, scope=scope) 
0

你可以只使用一個佔位符is_training:

isTraining = tf.placeholder(tf.bool) 

# create nn 
net = ... 
net = slim.dropout(net, 
        keep_prob=0.5, 
        is_training=isTraining) 
net = ... 

# training 
sess.run([net], feed_dict={isTraining: True}) 

# testing 
sess.run([net], feed_dict={isTraining: False}) 
+1

我試過這個,並且遇到了問題,因爲變量沒有被重用。我也遇到了我無法解釋的內存限制。 –

0

這要看情況下,解決方案是不同的。

我的第一個選擇是使用不同的流程來進行評估。你只需要檢查是否有新的關卡和加載權納入評價網絡(與is_training=False):

checkpoint = tf.train.latest_checkpoint(self.checkpoints_path) 
# wait until a new check point is available 
while self.lastest_checkpoint == checkpoint: 
    time.sleep(30) # sleep 30 seconds waiting for a new checkpoint 
    checkpoint = tf.train.latest_checkpoint(self.checkpoints_path) 
logging.info('Restoring model from {}'.format(checkpoint)) 
self.saver.restore(session, checkpoint) 
self.lastest_checkpoint = checkpoint 

第二個選項是每一個時代後您卸載圖形,並創建一個新的評價用圖。這個解決方案浪費了很多時間加載和卸載圖形。

第三個選項是分享權重。但是給這些網絡添加隊列或數據集可能會導致問題,所以您必須非常小心。我只用於連體網絡。

with tf.variable_scope('the_scope') as scope: 
    your_model(is_training=True) 
    scope.reuse_variables() 
    your_model(is_training=False)