0
我想在 TF 中為 CNN 的訓練預加載（prefetch）訓練數據。我的簡單實現如下。但是，我發現一個奇怪的現象：這似乎是一個同步過程——無論 PRE_FETCH 是 True 還是 False，裝載一批數據的時間成本幾乎相同。（原標題：TF 中的預加載數據）
class Demo(object):
    """Minimal CNN training loop with optional background data prefetching.

    When PRE_FETCH is true, a worker process repeatedly calls ``load_batch``
    and pushes results onto a multiprocessing queue; the training loop then
    pulls ready batches instead of loading them synchronously.

    NOTE(review): relies on module-level names defined elsewhere in the file:
    ``tf``, ``B/H/W/C``, ``PRE_FETCH``, ``MAX_ITER_SIZE``, ``build_model``,
    ``Queue``/``Process`` (multiprocessing), and ``time``.
    """

    def __init__(self):
        self._name = 'demo'

    def load_batch(self):
        # Load and return one batch of training data (implementation elided
        # in the original snippet).
        ...

    def prefetch(self, func):
        """Worker loop: call ``func`` forever, pushing each batch onto the queue.

        Runs inside the child process; ``self.queue`` must be set before the
        process is started (it is, in ``train``).
        """
        while True:
            data = func()
            # NOTE(review): multiprocessing.Queue pickles every batch on put()
            # and unpickles on get() — for large arrays this serialization cost
            # can dominate, which would explain prefetch showing no speedup.
            self.queue.put(data)

    def train(self):
        """Build the model and run the training loop, timing data loading."""
        input_data = tf.placeholder(tf.float32, shape=[B, H, W, C])
        optim_op = build_model(input_data)

        if PRE_FETCH:
            self.queue = Queue(30)  # bounded: worker blocks when 30 batches are waiting
            # BUG FIX: args must be a tuple. The original `args=(self.load_batch)`
            # is just the bound method (parentheses without a comma do not make
            # a tuple), so Process would try to unpack it as an argument list.
            self.process = Process(target=self.prefetch, args=(self.load_batch,))
            self.process.start()

            def cleanup():
                # Kill the (infinite-loop) worker when the program exits.
                self.process.terminate()
                self.process.join()
            import atexit
            atexit.register(cleanup)

        sess = tf.Session()
        i = 1
        while i < MAX_ITER_SIZE:
            start = time.time()
            if PRE_FETCH:
                tmp = self.queue.get()
            else:
                tmp = self.load_batch()
            end = time.time()
            # Converted from the Python-2 print statement; message unchanged.
            print('load data time: %s' % (end - start))
            # BUG FIX: the original line was missing its closing parenthesis.
            sess.run(optim_op, feed_dict={input_data: tmp})
            # BUG FIX: `i` was never incremented, making this an infinite loop.
            i += 1