2017-05-31 73 views
0

我是新手到tensorflow,我試圖獲得張量中最大值的索引。下面是代碼:沿多個維度的Tensorflow argmax

def select(input_layer): 

    shape = input_layer.get_shape().as_list() 

    rel = tf.nn.relu(input_layer) 
    print (rel) 
    redu = tf.reduce_sum(rel,3) 
    print (redu) 

    location2 = tf.argmax(redu, 1) 
    print (location2) 

sess = tf.InteractiveSession() 
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32) 
matI, matO = sess.run([I, select(I, 3)]) 
print(matI, matO) 

這裏是輸出:

Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32) 
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32) 
Tensor("ArgMax:0", shape=(32, 3), dtype=int64) 
... 

由於尺寸= 1在argmax功能的Tensor("ArgMax:0") = (32,3)形狀。有沒有辦法在應用argmax之前得到argmax輸出張量大小= (32,)而不是做reshape

+0

這有什麼錯了'tf.reshape(熱度,[32,-1])'? ['tf.argmax'](https://www.tensorflow.org/api_docs/python/tf/argmax)只會沿着一個軸減少 – martianwars

回答

1

您可能不希望輸出大小爲(32,),因爲當沿着幾個方向argmax時,通常希望所有縮小的維度都具有最大值的座標。在你的情況下,你會想要一個大小爲(32,2)的輸出。

你可以做一個二維argmax這樣的:

import numpy as np 
import tensorflow as tf 

x = np.zeros((10,9,8)) 
# pick a random position for each batch image that we set to 1 
pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)]) 

posext = np.concatenate([np.expand_dims([i for i in range(10)], axis=0), pos]) 
x[tuple(posext)] = 1 

a = tf.argmax(tf.reshape(x, [10, -1]), axis=1) 
pos2 = tf.stack([a // 8, tf.mod(a, 8)]) # recovered positions, one per batch image 

sess = tf.InteractiveSession() 
# check that the recovered positions are as expected 
assert (pos == pos2.eval()).all(), "it did not work" 
+0

感謝您的快速響應,但問題是我正在使用tensorflow版本0.9並且tf.stack在該版本中不存在。你有替代品嗎? – Maystro

+0

對不起,我沒有版本0.9,tensorflow.org只有0.10以上的文檔。如果存在連接,可以將其與'expand_dims'結合起來以模擬'stack'。 – user1735003