2017-08-07 58 views
2

摘要:使用新的tf.contrib.data.Dataset使圖形protobuff文件的大小加倍,我無法在Tensorboard中顯示圖形。Tensorflow數據集API雙打圖原始文件大小

細節:

我一起tf.contrib.learn.Experiment框架嘗試新TensorFlow tf.contrib.data.Dataset功能。我的輸入數據被定義爲input functions,它返回特徵和標籤的張量。

如果我創建具有tf.train.slice_input_producer功能我輸入功能就像下面的代碼塊(完整的代碼here),然後我得到的graph.pbtxt文件是620M和.meta文件大小約爲165M。

def train_inputs(): 
    with tf.name_scope('Training_data'): 
     x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1])) 
     y = tf.constant(mnist.train.labels) 
     sliced_input = tf.train.slice_input_producer(
      tensor_list=[x, y], shuffle=True) 
     return tf.train.shuffle_batch(
      sliced_input, batch_size=batch_size, 
      capacity=10000, min_after_dequeue=batch_size*10) 

現在,如果我創造我輸入功能與新tf.contrib.data.Dataset.from_tensor_slices就像下面的代碼塊(完整的代碼here),然後我得到的graph.pbtxt文件的大小加倍到1.3G和.meta文件一倍大小330M。

def train_inputs(): 
    with tf.name_scope('Training_data'): 
     images = mnist.train.images.reshape([-1, 28, 28, 1]) 
     labels = mnist.train.labels 
     dataset = tf.contrib.data.Dataset.from_tensor_slices(
      (images, labels)) 
     dataset = dataset.repeat(None) # Infinite 
     dataset = dataset.shuffle(buffer_size=10000) 
     dataset = dataset.batch(batch_size) 
     iterator = dataset.make_one_shot_iterator() 
     next_example, next_label = iterator.get_next() 
     return next_example, next_label 

現在因爲graph.pbtxt文件是如此之大TensorBoard需要年齡來解析這個文件,我無法直觀地調試我的模型圖。 我發現在Dataset documentation,這種增加的大小來自:「陣列的內容將被複制多次」solution將使用佔位符。使用tf.contrib.learn.Experiment框架時不過

sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels}) 

這似乎,是在我的掌握:然而,在這種情況下,我需要在numpy的陣列送入佔位符有活動會話初始化迭代器。

如何使用實驗框架初始化迭代器的初始化程序?或者找到解決方案來使用數據集API而不增加我的圖表大小?

回答

2

我使用tf.train.SessionRunHook發現了我的問題的解決方案。我創建了一個SessionRunHook對象初始化會話建立後的迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook): 
    def __init__(self): 
     super(IteratorInitializerHook, self).__init__() 
     self.iterator_initiliser_func = None 

    def after_create_session(self, session, coord): 
     self.iterator_initiliser_func(session) 

創建的數據集迭代器時,初始化函數設置:

iterator_initiliser_hook.iterator_initiliser_func = \ 
    lambda sess: sess.run(
     iterator.initializer, 
     feed_dict={images_placeholder: images, 
        labels_placeholder: labels}) 

我在掛鉤對象傳遞給train_monitorseval_hooks參數tf.contrib.learn.Experiment

生成的graph.pbtxt文件現在只有500K,而.meta文件只有244K。

Full example here.

+0

不錯。也解決了我的問題。但似乎是一種解決方法? 我的帖子:https://stackoverflow.com/questions/46207211/tensorflow-dataset-api-causes-graph-size-to-explode –

相關問題