2017-03-16 69 views
0

任何人都可以幫助弄清楚如何獲取一組嵌入?收集動態索引列表

我有一些代碼,預測各指標的概率,然後選擇最高:

# U is batch_size x max_sentence_length x embedding_size 
scores_per_index = find_start_preds(U ...) # batch_size x max_sentence_length x 1 
start_preds = tf.argmax(alpha, axis=1) # batch_size x 1 

我想,如果可能的話,再搶相關的每一個字的嵌入開始預測。那可能嗎?這是我在想什麼,但它不工作:(

u_s = U[:, start_preds, :] 

回答

1

你應該能夠使用tf.gather你想要的東西,而是讓你需要重新安排它僅適用於領先指標:

U2 = tf.transpose(U, [1, 0, 2]) 
u_s = tf.gather(U2, start_preds) 
u_s = tf.transpose(u_s, [1, 0, 2])