2017-06-05 57 views
4

對於傳輸學習,通常使用網絡作爲特徵提取器來創建要素的數據集,在該數據集上訓練另一個分類器(例如, SVM)。TensorFlow:tf.contrib.data中的「無法通過值捕獲有狀態節點」API

我想實現這個使用DataSet API(tf.contrib.data)和dataset.map()

# feature_extractor will create a CNN on top of the given tensor 
def features(feature_extractor, ...): 
    dataset = inputs(...) # This creates a dataset of (image, label) pairs 

    def map_example(image, label): 
     features = feature_extractor(image, trainable=False) 
     # Leaving out initialization from a checkpoint here... 
     return features, label 

    dataset = dataset.map(map_example) 

    return dataset 

做創建數據集的迭代器時失敗。

ValueError: Cannot capture a stateful node by value. 

這是事實,網絡的內核和偏見是變量,因此是有狀態的。對於這個特殊的例子,他們不一定非要。

是否有辦法讓Ops和特定的對象無狀態?

由於我使用的是tf.layers,我不能簡單地將它們創建爲常量,並且設置trainable=False既不會創建常量,也不會將變量添加到GraphKeys.TRAINABLE_VARIABLES集合中。

回答

9

不幸的是,tf.Variable本質上是有狀態的。但是,如果您使用Dataset.make_one_shot_iterator()創建迭代器,則只會出現此錯誤*爲避免此問題,您可以改爲使用Dataset.make_initializable_iterator(),但必須同時在返回的迭代器上運行iterator.initializer,然後運行tf.Variable對象的初始化程序用於輸入管道。


*這樣做的原因限制是Dataset.make_one_shot_iterator()實現細節和工作正在進行TensorFlow功能(Defun)支持,它使用封裝的數據集定義。由於使用諸如查找表和變量之類的有狀態資源比我們最初想象的更受歡迎,因此我們正在研究如何放鬆這一限制。