2016-10-02 83 views
0

我定義批處理兩個正態分佈的:Tensorflow:正態分佈廣播

dist = tf.contrib.distributions.Normal(mu=[1., 2.], sigma=10.) 

然後我想evalutate每個上的每個點[0,1,2,3的該分佈的PDF。 ]。不幸的是

dist.pdf([0.0, 1.0, 2.0, 3.0]) 

做出了一個錯誤:

ValueError: Dimensions must be equal, but are 4 and 2 

如何評估它在一個簡單的方法,並有形狀的張量(2,4)作爲輸出?

回答

0

當您運行dist.prob([0.0, 1.0, 2.0, 3.0]) tensorflow嘗試以不同的正態分佈在列表中的每個條目處評估pdf,但批處理只有兩個。解決的辦法是在每個值來評估的PDF,然後堆疊張量一起:

dist = tf.contrib.distributions.Normal(loc=[1., 2.], scale=10.) 
tf.stack([dist.prob(m) for m in [0.0, 1.0, 2.0, 3.0]],axis=1) 

這產生了張量與所要求的形狀:

<tf.Tensor 'stack_4:0' shape=(2, 4) dtype=float32>