2016-03-09 100 views
3

我有一個ByteTensor,並希望抓住有1的指數。在numpy的,我可以做類似等同於np.where()的Lua Torch?

a = np.array([1,0,1,0,1]) 
return np.where(a) 

這將返回(array([0, 2, 4]),)。火炬中定義了這個功能嗎?

(在我的具體情況,我想用這些指標來索引到幾個不同的張量的對象,但它會是不錯的知道如何在一般的做到這一點。)

回答

5

您可以使用torch.nonzero,如:

> a = torch.ByteTensor{1,0,1,0,1} 
> print(torch.nonzero(a))                       
1                             
3                             
5                             
[torch.LongTensor of size 3x1] 

如果你真的需要找到1-S只有你能鏈中的邏輯運算符:

> a = torch.ByteTensor{1,2,1,6,1} 
> a:eq(1):nonzero()