2017-08-24 60 views
1

(編輯原始問題) 在Tensorflow中,我們經常需要定義包含變量的函數,以在中間的網絡層上實現。有沒有評估此例如輸出的方式:Tensorflow中的函數

import tensorflow as tf 
def Mult(mult): 
    A = tf.get_variable([2,2], initializer = tf.zeros_initializer) 
    B = tf.get_variable([2,2], initializer = tf.zeros_initializer) 
    return mult*tf.matmul(A,B) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # Here we alter the variables A,B in some fashion e.g. in an optimisation algorithm 
    print(Mult(A,B)) 

產生一個錯誤

回答

0

有三個錯誤在代碼:

  1. 如果你創建一個函數,接受

    , B作爲參數,在函數中創建具有相同名稱的變量沒有多大意義。因此,要麼刪除AB參數,要麼不要在函數內部創建變量。

  2. Mult(A,B)返回張量。要檢索張量的值,您需要在會話中對其進行評估。會話會跟蹤參數AB的值,從中可以計算出Mult(A,B)的值。

  3. tf.get_variable函數需要一個name參數。下面

代碼修復你的錯誤:

import tensorflow as tf 

def Mult(): 
    A = tf.get_variable('A', shape=[2,2], initializer = tf.zeros_initializer) 
    B = tf.get_variable('B', shape=[2,2], initializer = tf.zeros_initializer) 
    return tf.matmul(A,B) 

with tf.Session() as sess: 
    result = Mult() 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(result)) 
+0

爲什麼它產生一個錯誤,如果行:結果= ......,和sess.run(TF .....切換? – sbb

+0

'sess.run(tf.global_variables_initializer())'初始化圖形中的所有全局變量。如果在調用'Mult()'之前調用它,圖形中沒有變量,而您最終得到兩個單位變量('A'和'B')。 – GeertH

0

沒有與代碼的幾個問題。首先,由於您將變量AB傳遞給該函數,因此您無需在函數內對它們進行初始化。因此,功能應該是這樣的:

def Mult(A,B): 
    return tf.matmul(A,B) 

接下來,你需要初始化之前定義的變量,因此該行所定義的變量後

sess.run(tf.global_variables_initializer()) 

應該會出現。

另外,如果你想獲得的Mult(A,B)數值,應打印

sess.run(Mult(A,B)) 

並不僅僅是Mult(A,B)(因爲後者只會給你一個張量對象)。

最後,您需要爲您定義的變量提供一個名稱。 [2,2]是形狀(函數的第二個參數)。第一個參數應該是名稱。

這裏是校正後的代碼:

import tensorflow as tf 
def Mult(A,B): 
    return tf.matmul(A,B) 

A = tf.get_variable('A',shape=[2,2], initializer = tf.zeros_initializer) 
B = tf.get_variable('B',shape=[2,2], initializer = tf.zeros_initializer) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(Mult(A,B))) 

這將打印

[[ 0. 0.] 
[ 0. 0.]] 
相關問題