2017-02-28 190 views
8

我試圖解決的問題如下: 我有一個文件名列表trainimgs的文件名。我與它的capacity=len(trainimgs)min_after_dequeue=0定義的TensorFlow:從多個線程入隊和出隊隊列

  • tf.RandomShuffleQueue
  • 對於指定的epochlimit,此tf.RandomShuffleQueue預計將填充trainimgs次數。
  • 許多線程預計並行工作。每個線程從tf.RandomShuffleQueue中取出一個元素,並對其執行一些操作並將其排入另一個隊列。我有這個權利。
  • 但是,一旦1 epochtrainimgs已被處理且tf.RandomShuffleQueue爲空,假設當前時期e < epochlimit,隊列必須再次被填滿並且線程必須再次工作。

好消息是:我有它一定的情況下工作(見PS在最後!!)

壞消息是:我認爲有這樣做的更好的辦法這個。

我使用要做到這一點,現在是如下所述的方法(I簡化了功能和已刪除基礎的電子圖像處理的預處理和隨後的入隊但處理的心臟保持相同!!):

with tf.Session() as sess: 
    train_filename_queue = tf.RandomShuffleQueue(capacity=len(trainimgs), min_after_dequeue=0, dtypes=tf.string, seed=0) 
    queue_size = train_filename_queue.size() 
    trainimgtensor = tf.constant(trainimgs) 
    close_queue = train_filename_queue.close() 
    epoch = tf.Variable(initial_value=1, trainable=False, dtype=tf.int32) 
    incrementepoch = tf.assign(epoch, epoch + 1, use_locking=True) 
    supplyimages = train_filename_queue.enqueue_many(trainimgtensor) 
    value = train_filename_queue.dequeue() 

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    tf.train.start_queue_runners(sess, coord) 
    sess.run(supplyimages) 
    lock = threading.Lock() 
    threads = [threading.Thread(target=work, args=(coord, value, sess, epoch, incrementepoch, supplyimages, queue_size, lock, close_queue)) for i in range(200)] 
    for t in threads: 
     t.start() 
    coord.join(threads) 

功函數如下:

def work(coord, val, sess, epoch, incrementepoch, supplyimg, q, lock,\ 
     close_op): 
while not coord.should_stop(): 
    if sess.run(q) > 0: 
     filename, currepoch = sess.run([val, epoch]) 
     filename = filename.decode(encoding='UTF-8') 
     print(filename + ' ' + str(currepoch)) 
    elif sess.run(epoch) < 2: 
     lock.acquire() 
     try: 
      if sess.run(q) == 0: 
       print("The previous epoch = %d"%(sess.run(epoch))) 
       sess.run([incrementepoch, supplyimg]) 
       sz = sess.run(q) 
       print("The new epoch = %d"%(sess.run(epoch))) 
       print("The new queue size = %d"%(sz)) 
     finally: 
      lock.release() 
    else: 
     try: 
      sess.run(close_op) 
     except tf.errors.CancelledError: 
      print('Queue already closed.') 
     coord.request_stop() 
return None 

所以,儘管這個作品,我有一種感覺,有一個更好的和更清潔的方式來實現這一目標。所以,簡而言之,我的問題是:

  1. 在TensorFlow中實現此任務有沒有更簡單更清晰的方法?
  2. 這段代碼的邏輯有問題嗎?我對多線程場景不是很有經驗,所以任何忽略我的注意力的明顯缺陷都會對我很有幫助。

P.S:看來這段代碼並不完美。當我運行120萬個圖像和200個線程時,它運行。然而,當我運行了10張和20個線程,它提供了以下錯誤:

CancelledError (see above for traceback): RandomShuffleQueue '_0_random_shuffle_queue' is closed. 
    [[Node: random_shuffle_queue_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](random_shuffle_queue, Const)]] 

我想我得到了涵蓋except tf.errors.CancelledError。這到底是怎麼回事 ?

回答

3

我終於找出答案。問題是多個線程在work()函數的各個點上發生衝突。 以下work()函數完美工作。

def work(coord, val, sess, epoch, maxepochs, incrementepoch, supplyimg, q, lock, close_op): 
    print('I am thread number %s'%(threading.current_thread().name)) 
    print('I can see a queue with size %d'%(sess.run(q))) 
    while not coord.should_stop(): 
     lock.acquire() 
     if sess.run(q) > 0: 
      filename, currepoch = sess.run([val, epoch]) 
      filename = filename.decode(encoding='UTF-8') 
      tid = threading.current_thread().name 
      print(filename + ' ' + str(currepoch) + ' thread ' + str(tid)) 
     elif sess.run(epoch) < maxepochs: 
      print('Thread %s has acquired the lock'%(threading.current_thread().name)) 
      print("The previous epoch = %d"%(sess.run(epoch))) 
      sess.run([incrementepoch, supplyimg]) 
      sz = sess.run(q) 
      print("The new epoch = %d"%(sess.run(epoch))) 
      print("The new queue size = %d"%(sz)) 
    else: 
      coord.request_stop() 
     lock.release() 

    return None 
1

我建議讓一個線程調用enqueue_many時代排列正確數量的圖像。它可以關閉隊列。這會讓你簡化你的工作函數和其他線程。

+0

謝謝,但我想使用多個線程來加速,因爲有我需要做的複雜預處理步驟 – Ujjwal

+0

您可以使用一個線程將文件名排入主隊列,然後使用多個線程將這些文件名出列,預處理,並將它們排入最終隊列。 –

1

我認爲GIL會阻止在這些線程中完成任何實際的並行操作。

要獲得tensorflow的性能,您需要保持數據的張量流。

Tensor Flow的reading data guide解釋瞭如何解決類似的問題。

更具體地說,你似乎已經重寫了一大段string_input_producer

+0

我正在使用實際數據。 'string_input_producer()'不會告訴它在任何時刻提取哪個數據。它只是讓任何一段數據都被提取出'epoch'次。所以,我的實現不是重寫'string_input_producer()'。我明白,我這樣做的方式可能不是最好的方式,但我需要對數據來源​​的特定時期和迭代進行直接和非常精確的檢查,並且我似乎沒有在數據指南中找到任何內容關於它。但我會再讀一遍。 – Ujjwal

+0

啊,謝謝澄清。我想我不完全明白這個問題。爲什麼不使用(字符串,時代)對? – mdaoust

+0

我已經使用它。我用'epoch'變量來計算'epoch' – Ujjwal