2017-09-01 97 views
-1

我想知道如何在3D數組中使用tf.argmax如何在3d數組中使用argmax tensorflow函數?

我輸入的數據是這樣的:

[[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]], 
[[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]] 

而且我想通過這樣的該輸入數據得到argmax的輸出:

[[2, 4, 0], [2, 3, 1]] 

而且我想用softmax_cross_entropy_with_logits功能在這格式。

我應該如何使用tf.nn.softmax_cross_entropy_with_logits函數和tf.equal(tf.argmax)tf.reduce_mean(tf.cast)

回答

0

您可以使用tf.argmax沿axis=3

a = tf.constant([[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]], 
     [[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]]) 
b = tf.argmax(a, axis=2)