我是新手到tensorflow
,我試圖獲得張量中最大值的索引。下面是代碼:沿多個維度的Tensorflow argmax
def select(input_layer):
shape = input_layer.get_shape().as_list()
rel = tf.nn.relu(input_layer)
print (rel)
redu = tf.reduce_sum(rel,3)
print (redu)
location2 = tf.argmax(redu, 1)
print (location2)
sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32)
matI, matO = sess.run([I, select(I, 3)])
print(matI, matO)
這裏是輸出:
Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...
由於尺寸= 1在argmax
功能的Tensor("ArgMax:0") = (32,3)
形狀。有沒有辦法在應用argmax
之前得到argmax
輸出張量大小= (32,)
而不是做reshape
?
這有什麼錯了'tf.reshape(熱度,[32,-1])'? ['tf.argmax'](https://www.tensorflow.org/api_docs/python/tf/argmax)只會沿着一個軸減少 – martianwars