2017-10-14 60 views
0

我有ñ帶有佔位符的所有輸入的網絡,我想將它們全部鏈接到另一個佔位符(之後創建)作爲通用輸入。在張量流中分組的佔位符

class GroupOfNetworks(object): 
    def __init__(self,subtask_nets,ob_space): 
     self.x_inputs = [st_net.x for st_net in subtask_nets] #list of network inputs 

其中st_net.x是一個佔位符,聲明如下。

class Network(object): 
    def __init__(self, ob_space): 
      self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) `#single network input 

我想有一個共同輸入到所有這些網絡的,所以我只需要有一鍵 - 值對我feed_dict。我試着在佔位符(下面的代碼片段)上做一個賦值操作,但是這會引發一個錯誤,因爲它們是張量而不是變量。

#in class GroupOfNetworks... 
common_x = tf.placeholder(tf.float32, [None] + list(ob_space),"common_input") 
set_input = tf.assign(self.x_inputs[0].x,common_x,"link_subtask_input") # DOES NOT WORK 

到目前爲止我使用一個程序生成feed_dict(如下所示),但這不是在圖形上的和在從一.meta文件加載圖形不能被導入。

def make_common_feed_dict(self,x): 
    return {placeholder:x for placeholder in self.x_inputs} 

有沒有人知道更好的解決方案?

回答

1

由於您需要爲網絡中的每個網絡使用一個佔位符(並且因此具有相同的輸入),因此只需使用相同的佔位符即可。

不是在對象__init__方法內創建佔位符,而是在外部創建它並將其傳遞給您創建的每個對象。 做這樣的事情:

# Define your network in this way 
class Network(object): 
    def __init__(self, placeholder): 
      self.x = placeholder 

然後,初始化Network對象定義佔位符前,然後用它

input_placeholder = tf.placeholder(tf.float32, [None] + list(ob_space)) 

network_a = Network(input_placeholder) 
network_b = Network(input_placeholder) 

他們,假設Network obejects得到了get方法來獲取輸出張量,你可以執行network_anetwork_b給他們提供相同的值:

sess.run([network_a.get(), network_b.get()], feed_dict={input_placeholder: value}) 
+0

是的,如果我不需要獨立運行'Network'對象並且它們可以被修改,就可以工作。您是否知道是否有辦法處理無法修改「網絡」對象但僅使用佔位符引用的情況? – yokian

相關問題