2016-12-01 63 views
2

我有一個關於通過tensorflow python api更新張量值的基本問題。更新tensorflow中的變量值

考慮代碼片段:

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

現在讓我們假設我想實現某種算法,用於迭代更新W的值(例如一些優化算法中)。這將涉及類似步:

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 

這裏的問題是,W在LHS是一個新的張量,而不是在yhat = tf.matmul(x, W)使用W!也就是說,創建了一個新變量,並且我的「模型」中使用的W的值不會更新。

現在圍繞這一方法是

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 
    yhat = tf.matmul(x, W) 

導致創建一個新的「模式」我的每個循環迭代的!

有沒有一種更好的方式來實現這個(在python中),而不是爲循環的每次迭代創建一大堆新模型 - 而是更新原始張量W「in-place」可以這麼說?

回答

2

變量有一個賦值方法。嘗試:W.assign(f(W,resid))

+0

這似乎工作!謝謝。 – firdaus

0

@ aarbelle的簡潔答案是正確的,我會擴展一下以防有人需要更多信息。最後2行用於更新W.

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

... 

for i in range(max_its): 
    resid = y_hat - y 
    update = W.assign(f(W , resid)) # do not forget to initialize tf variables. 
    # "update" above is just a tf op, you need to run the op to update W. 
    sess.run(update)