2017-06-03

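TensorFlow summaries with threads

I am trying to add summaries to a TensorFlow graph that runs asynchronously. Everything works in the single-threaded case, but as soon as I switch to multiple threads the summaries seem to disappear. Here is a toy example of what I am trying to do: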

import tensorflow as tf  # 1.1.0
import threading

SUMMARY_DIR = "./summaries"  # assumed value; not defined in the original post


class Worker:
    def __init__(self):
        # dtype is passed by keyword: the second positional argument of
        # tf.Variable is `trainable`, not the dtype
        self.x = tf.Variable([1, -2, 3], dtype=tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], dtype=tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary_str = SESS.run(tf.summary.merge_all())
            WRITER.add_summary(summary_str, i)
            WRITER.flush()

COORD = tf.train.Coordinator()
SESS = tf.Session()
WRITER = tf.summary.FileWriter(SUMMARY_DIR, SESS.graph)

# Single Thread case
w = Worker()
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))
w.work()

This works fine. However, if I switch to multiple threads:

# Multi-thread case 
workers = [Worker() for i in range(4)] 
SESS.run(tf.global_variables_initializer()) 
print(tf.get_collection(tf.GraphKeys.SUMMARIES)) 

worker_threads = [] 
for worker in workers: 
    job = lambda: worker.work() 
    t = threading.Thread(target=job) 
    t.start() 
    worker_threads.append(t) 
COORD.join(worker_threads) 

Whenever tf.summary.merge_all() is called, I get the following error, which happens because it cannot see any summaries:

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/home/anjum/PycharmProjects/junk.py", line 43, in <lambda>
    job = lambda: worker.work()
  File "/home/anjum/PycharmProjects/junk.py", line 22, in work
    summary_str = SESS.run(tf.summary.merge_all())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 969, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 227, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

If I put print(tf.get_collection(tf.GraphKeys.SUMMARIES)) inside work(), it returns an empty list. So my summaries are getting lost somewhere.
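
A likely cause, going by the TensorFlow 1.x documentation for `tf.Graph.as_default`, is that the default graph is a property of the current thread, so a collection populated on the main thread can look empty from a new thread. The following is a pure-Python analogy of that behaviour, not TensorFlow's implementation; `DefaultRegistry`, `add_summary`, and `merge_all` are hypothetical names made up for the illustration.

```python
import threading


class DefaultRegistry(threading.local):
    """Hypothetical stand-in for a per-thread default graph with one collection."""

    def __init__(self):
        # threading.local calls __init__ again the first time each new thread
        # touches the object, so every thread starts with its own empty list.
        self.summaries = []


REGISTRY = DefaultRegistry()


def add_summary(name):
    # Registers a summary in the calling thread's collection only.
    REGISTRY.summaries.append(name)


def merge_all():
    # Mimics the failure mode in the question: nothing collected gives None.
    return list(REGISTRY.summaries) or None


add_summary("Dot_Product")   # registered on the main thread
main_view = merge_all()      # the main thread sees its own summary

seen = {}


def worker():
    # A new thread gets a fresh, empty collection.
    seen["worker_view"] = merge_all()


t = threading.Thread(target=worker)
t.start()
t.join()

print(main_view)             # ['Dot_Product']
print(seen["worker_view"])   # None
```

If this analogy holds, anything that looks up the default graph lazily, such as calling `tf.summary.merge_all()` inside `work()`, would need to run on the thread that built the graph, or inside a `with SESS.graph.as_default():` block in the worker thread.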

Could someone please explain how to use summaries correctly with multiple threads?

Answer

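I think I have figured it out: the summaries must be merged once in the constructor, like this. I am not 100% sure why TensorFlow is so picky about this.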

class Worker:
    def __init__(self):
        # dtype is passed by keyword: the second positional argument of
        # tf.Variable is `trainable`, not the dtype
        self.x = tf.Variable([1, -2, 3], dtype=tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], dtype=tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

        # Build the merged summary op once, while the worker is constructed
        self.summarise = tf.summary.merge_all()

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary = SESS.run(self.summarise)
            WRITER.add_summary(summary, i)
            WRITER.flush()