0
我在lua中有以下代碼writtein。從torch刪除項目。傳感器
我想從scores
及其相應的分數得到N個最高分數的索引。
它看起來像我將不得不迭代從scores
刪除當前的最大值,並再次檢索最大值,但找不到一個合適的方式來做到這一點。
nqs=dataset['question']:size(1);
scores=torch.Tensor(nqs,noutput);
qids=torch.LongTensor(nqs);
for i=1,nqs,batch_size do
xlua.progress(i, nqs)
r=math.min(i+batch_size-1,nqs);
scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r);
-- print(scores)
end
tmp,pred=torch.max(scores,2);
您能否詳細說明「從最後找到N個值的切片val和id」部分? – ytrewq
我喜歡val {{{1},{val:size(2)-N + 1,val:size(2)}}]我喜歡''並且與'id'相同,因爲'N'最大的元素是在排序張量的末尾。 – Ash
請注意,這不會解決重複的問題(我的意思是如果'scores'包含,* eg *,例如,100的最大值的兩倍),但我認爲這不是問題,因爲它不是問題在你的問題中提到。 – Ash