1
我有一個numpy的數據陣列,在這裏我只需要保留n
的最高值,並且零其他值爲零。保持numpy數組的每一行的n個最高值,並將其他所有值保存爲零
我目前的解決方案:
import numpy as np
np.random.seed(30)
# keep only the n highest values
n = 3
# Simple 2x5 data field for this example, real life application will be exteremely large
data = np.random.random((2,5))
#[[ 0.64414354 0.38074849 0.66304791 0.16365073 0.96260781]
# [ 0.34666184 0.99175099 0.2350579 0.58569427 0.4066901 ]]
# find indices of the n highest values per row
idx = np.argsort(data)[:,-n:]
#[[0 2 4]
# [4 3 1]]
# put those values back in a blank array
data_ = np.zeros(data.shape) # blank slate
for i in xrange(data.shape[0]):
data_[i,idx[i]] = data[i,idx[i]]
# Each row contains only the 3 highest values per row or the original data
#[[ 0.64414354 0. 0.66304791 0. 0.96260781]
# [ 0. 0.99175099 0. 0.58569427 0.4066901 ]]
在上面的代碼,data_
有n
最高值和其他一切歸零。即使data.shape[1]
小於n
,也可以很好地工作。但唯一的問題是for loop
,這很慢,因爲我的實際使用案例是在非常大的陣列上。
是否有可能擺脫for循環?
爲了清晰起見,我編輯了我的解決方案。使用上述數據,您的每行不會產生n個最高值。嘗試使用您的解決方案和相同的數據來查看差異。 – Fnord
@Fnord:oops,忘了額外的'argsort'。它需要另一個參數來讓它像rankdata一樣行事(我習慣於在Series或DataFrame上使用.rank('dense'))。 – DSM