2017-09-02 45 views
0

我想剪輯我的網絡中的鑑別器的所有訓練變量。集合中的Tensorflow剪輯值?

我得到鑑變量是這樣的:

A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_') 
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_') 
discriminatorVars = self.A_d_vars + self.B_d_vars 

現在,如果我嘗試這樣做 discriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1))夾的所有值[0.01,0.1]它不會工作作爲瓦爾是python列表而不是張量。

我也嘗試過這一點,但它不工作:

self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), var_list)) 

它說,list對象沒有assign方法。

當前我遍歷列表中的所有變量,並致電self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
問題是,它非常緩慢。

如何批量更新集合以使其值被剪切?

回答

1

嘗試製作你想要做的分配操作的列表,並使用tf.grouphttps://www.tensorflow.org/api_docs/python/tf/group)對它們進行分組。將tf.group運營商通過sess.run

Session.run()可以有一個不平凡的開銷,所以你想在一個Session.run()調用中做所有的更新。

+0

感謝您的評論。我對tensorflow很新,你能否寫一個'tf.group'如何使用的例子? – Cristy