
I have defined the following class in TensorFlow. How can I reuse a variable scope name?

import tensorflow as tf
import tensorflow.contrib.slim as slim

class BasicNetwork(object):
    def __init__(self, scope, task_name, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.task_name = task_name
        self.__create_network(scope, img_shape=img_shape)

    def __create_network(self, scope, img_shape=(80, 80)):
        with tf.variable_scope(scope):
            with tf.variable_scope(self.task_name):
                with tf.variable_scope('input_data'):
                    # cfg.HIST_LEN is the number of stacked input frames (project config)
                    self.inputs = tf.placeholder(shape=[None, *img_shape, cfg.HIST_LEN], dtype=tf.float32)
                with tf.variable_scope('networks'):
                    with tf.variable_scope('conv_1'):
                        self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                                                  kernel_size=[8, 8], stride=4, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_2'):
                        self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                                                  kernel_size=[4, 4], stride=2, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_3'):
                        self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                                                  kernel_size=[3, 3], stride=1, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('f_c'):
                        self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512,
                                                       activation_fn=tf.nn.elu, trainable=self.is_train)

I want to create two instances of BasicNetwork with different task names; the scope is 'global' for both. But when I inspect the output, I see:

ipdb> for i in net_1.layres: print(i) 
Tensor("global/simple/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2) 
Tensor("global/simple/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2) 
Tensor("global/simple/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2) 
Tensor("global/simple/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2) 

ipdb> for i in net_2.layres: print(i) 
Tensor("global_1/supreme/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2) 
Tensor("global_1/supreme/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2) 
Tensor("global_1/supreme/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2) 
Tensor("global_1/supreme/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2) 

As you can see in the output, a new scope global_1 has been created, but I want it to stay global. I tried setting reuse=True, but then I found that reuse=True cannot be used when no scope named global exists yet. What should I do?
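For reference, a minimal standalone sketch of the two effects at play here, assuming TensorFlow 1.x (the scope and variable names are illustrative only): re-opening a scope by the same string keeps the variable scope but uniquifies the name scope used for op names, and reuse=True only finds variables that already exist.

import tensorflow as tf

with tf.variable_scope('global'):
    w = tf.get_variable('w', shape=[1])   # variable name: global/w
    y = tf.identity(w)                    # op name:       global/Identity

# Re-opening the same string keeps the *variable* scope 'global', but the
# *name* scope is uniquified, so new ops are prefixed 'global_1/...'.
# reuse=True only works if the requested variables already exist.
with tf.variable_scope('global', reuse=True):
    w2 = tf.get_variable('w')             # same variable: global/w
    y2 = tf.identity(w2)                  # op name:       global_1/Identity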

Answer


Using reuse does give you the existing variables, but the variables to be reused must already exist in the graph. If variables with the same names already exist, they can then be used by other operations.

class BasicNetwork(object):
    def __init__(self, scope, task_name, reuse, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.reuse = reuse
        self.task_name = task_name
        self.__create_network(scope, reuse=self.reuse, img_shape=img_shape)

    def __create_network(self, scope, reuse=None, img_shape=(80, 80)):
        with tf.variable_scope(scope, reuse=reuse):
            ...
            # delete this line: with tf.variable_scope(self.task_name):
            # or replace it with: with tf.name_scope(self.task_name):

trainnet = BasicNetwork('global', taskname, None)
# reuse the variables created by trainnet
valnet = BasicNetwork('global', taskname, True)

Hi, the task names are different: one is 'simple' and the other is 'supreme'. –


That won't affect the main logic; I just wrote it that way for illustration. –


Sorry, but I get an error: 'ValueError: Variable global/supreme/networks/conv_1/Conv/weights does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?'. The code is 'net_1 = BasicACNetwork('global', 'simple', reuse=None); net_2 = BasicACNetwork('global', 'supreme', reuse=True)' –
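For completeness, a minimal sketch of why this ValueError appears and one way around it, assuming TensorFlow 1.4 or newer (tf.AUTO_REUSE); the scope names are illustrative, and this only fixes the variable lookup, not the 'global_1' op-name prefix, which comes from name-scope uniquification.

import tensorflow as tf

# net_1 created its variables under 'global/simple/...'; with reuse=True,
# tf.get_variable only looks up existing variables with the exact same
# full name, so asking for 'global/supreme/...' raises this ValueError.

# tf.AUTO_REUSE (TF >= 1.4) creates a variable when it does not exist yet
# and reuses it when it does, so it works for both task names:
with tf.variable_scope('global', reuse=tf.AUTO_REUSE):
    with tf.variable_scope('supreme'):
        w = tf.get_variable('w', shape=[1])   # created on the first call
with tf.variable_scope('global', reuse=tf.AUTO_REUSE):
    with tf.variable_scope('supreme'):
        w2 = tf.get_variable('w', shape=[1])  # the same variable is reused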
