2016-08-03 186 views
2

我有一個複雜的網絡的一個人爲的版本:爲什麼這個tensorflow循環需要這麼多的內存?

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

for i in range(int(1e6)): 
    a = a * b 

我的直覺是,這應該需要很少的內存。只是初始數組分配的空間和一系列利用節點的命令並在每一步中覆蓋存儲在張量'a'中的內存。但內存使用增長相當迅速。

這裏發生了什麼,當我計算張量並覆蓋很多次時,如何減少內存使用量?

編輯:

由於雅羅斯拉夫的建議解決方案橫空出世使用while_loop以儘量減少對圖中的節點的數量要。這樣做效果很好,速度更快,所需內存更少,並且全部包含在圖表中。

import tensorflow as tf 

a = tf.ones([1000]) 
b = tf.ones([1000]) 

cond = lambda _i, _1, _2: tf.less(_i, int(1e6)) 
body = lambda _i, _a, _b: [tf.add(_i, 1), _a * _b, _b] 

i = tf.constant(0) 
output = tf.while_loop(cond, body, [i, a, b]) 

with tf.Session() as sess: 
    result = sess.run(output) 
    print(result) 

回答

6

a*b命令轉換爲tf.mul(a, b),這相當於tf.mul(a, b, g=tf.get_default_graph())。此命令將Mul節點添加到當前的Graph對象,因此您正嘗試將100萬個Mul節點添加到當前圖形。這也是有問題的,因爲你不能序列化大於2GB的Graph對象,所以一旦你處理這麼大的圖,有一些檢查可能會失敗。

我建議閱讀Programming Models for Deep Learning MXNet人。 TensorFlow是其術語中的「象徵性」編程,您將它視爲必不可少的。

爲了得到你想要使用Python循環中,您可以構建乘法運算一次,反覆運行它的東西,用feed_dict養活更新

mul_op = a*b 
result = sess.run(a) 
for i in range(int(1e6)): 
    result = sess.run(mul_op, feed_dict={a: result}) 

爲了更高的效率,你可以使用tf.Variable對象和var.assign避免Python的< - > TensorFlow數據傳輸

+0

這很有道理。我只想調用sess.run一次,並且一步處理從輸入到輸出的所有計算,因爲我正在反向傳播大圖。是否可以在不添加額外節點的情況下執行此循環? – jstaker7

+0

如果您在'tf.Variable'對象中將輸入保存到* b中,您將隔離它的依賴關係,因此您可以在該節點上執行'sess.run一百萬次,而無需評估圖中的任何其他內容 –

+0

不確定我很瞭解。假設'b'是一個可訓練的權重矩陣,在back-prop期間得到更新。如果我多次調用sess.run,不但會招致多次sess.run調用的額外開銷,而且會將該計算與梯度計算斷開連接,並需要做一些繁瑣的工作以確保其得到正確更新。這些假設是否正確?我想我希望有更好的方法來處理這個圖表。 – jstaker7