2017-07-29 69 views
0

我在lua中有以下代碼writtein。從torch刪除項目。傳感器

我想從scores及其相應的分數得到N個最高分數的索引。

它看起來像我將不得不迭代從scores刪除當前的最大值,並再次檢索最大值,但找不到一個合適的方式來做到這一點。

nqs=dataset['question']:size(1); 
scores=torch.Tensor(nqs,noutput); 
qids=torch.LongTensor(nqs); 
for i=1,nqs,batch_size do 
    xlua.progress(i, nqs) 
    r=math.min(i+batch_size-1,nqs); 
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r); 
-- print(scores) 
end 

tmp,pred=torch.max(scores,2); 

回答

1

我希望我沒有誤會,因爲你展示的代碼(尤其是福爾循環)並沒有真正似乎相關要你想做的事情。無論如何,這是我該怎麼做。

sr=scores:view(-1,scores:size(1)*scores:size(2)) 
val,id=sr:sort() 
--val is a row vector with the values stored in increasing order 
--id will be the corresponding index in sr 
--now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with 
col=(index-1)%scores:size(2)+1 
row=math.ceil(index/scores:size(2)) 

希望這有助於。

+0

您能否詳細說明「從最後找到N個值的切片val和id」部分? – ytrewq

+0

我喜歡val {{{1},{val:size(2)-N + 1,val:size(2)}}]我喜歡''並且與'id'相同,因爲'N'最大的元素是在排序張量的末尾。 – Ash

+0

請注意,這不會解決重複的問題(我的意思是如果'scores'包含,* eg *,例如,100的最大值的兩倍),但我認爲這不是問題,因爲它不是問題在你的問題中提到。 – Ash