2017-04-19 55 views
0

我想在訓練期間預加載訓練數據CNN in TF我的簡單實現如下。但是,我發現一個奇怪的現象。這似乎是一個同步過程。裝載一批數據的時間成本幾乎相同,無論是PRE_FETCHTrue還是FalseTF中的預加載數據

class Demo(object): 
    def __init__(self): 
     self._name = 'demo' 

    def load_batch(self): 
     ... 

    def prefetch(self, func): 
     while True: 
      data = func() 
      self.queue.put(data) 

    def train(self): 
     input_data = tf.placeholder(tf.float32, shape=[B, H, W, C]) 
     optim_op = build_model(input_data) 

     if PRE_FETCH: 
      self.queue = Queue(30) 
      self.process = Process(target=self.prefetch, args=(self.load_batch)) 
      self.process.start() 
      def cleanup(): 
       self.process.terminate() 
       self.process.join() 
      import atexit 
      atexit.register(cleanup) 
     sess = tf.Session() 
     i = 1 
     while i < MAX_ITER_SIZE: 
      if PRE_FETCH: 
       start = time.time() 
       tmp = self.queue.get() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      else: 
       start = time.time() 
       tmp = self.load_batch() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      sess.run(optim_op, feed_dict={input_data: tmp} 

回答

0

需要花費時間的是通過佔位符將數據加載到圖中。如果你希望你的預加載有效,你應該調查替換你的python隊列並用tensorflow圖中的操作線程mecanisme。在tensorflow網站上有一個很好的教程:https://www.tensorflow.org/programmers_guide/reading_data