0
我想剪輯我的網絡中的鑑別器的所有訓練變量。集合中的Tensorflow剪輯值?
我得到鑑變量是這樣的:
A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_')
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_')
discriminatorVars = self.A_d_vars + self.B_d_vars
現在,如果我嘗試這樣做 discriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1))
夾的所有值[0.01,0.1]它不會工作作爲瓦爾是python列表而不是張量。
我也嘗試過這一點,但它不工作:
self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), var_list))
它說,list
對象沒有assign
方法。
當前我遍歷列表中的所有變量,並致電self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
問題是,它非常緩慢。
如何批量更新集合以使其值被剪切?
感謝您的評論。我對tensorflow很新,你能否寫一個'tf.group'如何使用的例子? – Cristy