2016-07-26 74 views
0

的形狀,我測試下面的代碼腳本關於打印張

import tensorflow as tf 

a, b, c = 2, 3, 4 
x = tf.Variable(tf.random_normal([a, b, c], mean=0.0, stddev=1.0, dtype=tf.float32)) 
s = tf.shape(x) 
print(s) 

init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 
print(sess.run(s)) 

運行代碼得到如下結果

Tensor("Shape:0", shape=(3,), dtype=int32) 
[2 3 4] 

貌似只有第二打印給人的可讀格式。第一次印刷真的在做什麼或如何理解第一次輸出?

回答

1

s = tf.shape(x)的調用定義了一個符號(但非常簡單)的TensorFlow計算,只有在您致電sess.run(s)時才執行。

當您執行print(s) Python將打印TensorFlow知道張量s的所有內容,而無需實際評估它。由於它是tf.shape() op的輸出,因此TensorFlow知道它具有類型tf.int32,並且TensorFlow也可以推斷出它是長度爲3的向量(因爲x靜態地被稱爲來自變量定義的3-D張量)。

注意,在很多情況下,你就可以獲得更多的形狀信息而無需打印特定張的靜態形狀,使用Variable.get_shape()方法(類似於其表弟Tensor.get_shape())調用張量:

# Print the statically known shape of `x`. 
print(x.get_shape()) 
# ==> "(2, 3, 4)"