我想問如何實現它只是使用np.where
是一個位X/Y problem位。
所以我會盡力解釋我將如何優化這個功能。
我的第一直覺就是擺脫for
循環,這是痛點反正:
import numpy as np
from scipy.stats import logistic
def func1(y, X, thresholds):
ll = 0.0
for row in zip(y, X):
if row[0] == 0:
ll += logistic.logcdf(thresholds[0] - row[1])
elif row[0] == len(thresholds):
ll += logistic.logcdf(row[1] - thresholds[-1])
else:
diff_prob = logistic.cdf(thresholds[row[0]] - row[1]) - \
logistic.cdf(thresholds[row[0] - 1] - row[1])
diff_prob = 10 ** -5 if diff_prob < 10 ** -5 else diff_prob
ll += np.log(diff_prob)
return ll
y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
print(func1(y, X, thresholds))
我剛纔所取代i
與row[0]
,在不改變環路的語義。所以這是一個更少的循環。
現在我想if-else
的不同分支中的語句的表格是相同的。爲此:
import numpy as np
from scipy.stats import logistic
def func2(y, X, thresholds):
ll = 0.0
for row in zip(y, X):
if row[0] == 0:
ll += logistic.logcdf(thresholds[0] - row[1])
elif row[0] == len(thresholds):
ll += logistic.logcdf(row[1] - thresholds[-1])
else:
ll += np.log(
np.maximum(
10 ** -5,
logistic.cdf(thresholds[row[0]] - row[1]) -
logistic.cdf(thresholds[row[0] - 1] - row[1])
)
)
return ll
y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
print(func2(y, X, thresholds))
現在每個分支中的表達式的格式爲ll += expr
。
在此piont有優化可以採取幾個不同的路徑。你可以嘗試通過把它寫成理解來優化循環,但我懷疑它不會提高速度。
另一種途徑是將if
條件拉出循環。這就是與np.where
你的意圖是還有:
import numpy as np
from scipy.stats import logistic
def func3(y, X, thresholds):
y_0 = y == 0
y_end = y == len(thresholds)
y_rest = ~(y_0 | y_end)
ll_1 = logistic.logcdf(thresholds[0] - X[ y_0 ])
ll_2 = logistic.logcdf(X[ y_end ] - thresholds[-1])
ll_3 = np.log(
np.maximum(
10 ** -5,
logistic.cdf(thresholds[y[ y_rest ]] - X[ y_rest ]) -
logistic.cdf(thresholds[ y[y_rest] - 1 ] - X[ y_rest])
)
)
return np.sum(ll_1) + np.sum(ll_2) + np.sum(ll_3)
y = np.array([0, 1, 2])
X = np.array([2, 2, 2])
thresholds = np.array([2, 3])
print(func3(y, X, thresholds))
注意,我轉身X
爲np.array
,以便能夠在其上使用花哨的索引。
在這一點上,我敢打賭,它是足夠快,爲我的目的。但是,根據您的要求,您可以提前或超出此點。
在我的電腦,我得到如下結果:
y = np.random.random_integers(0, 10, size=(10000,))
X = np.random.random_integers(0, 10, size=(10000,))
thresholds = np.cumsum(np.random.rand(10))
%timeit func(y, X, thresholds) # Original
1 loops, best of 3: 1.51 s per loop
%timeit func1(y, X, thresholds) # Removed for-loop
1 loops, best of 3: 1.46 s per loop
%timeit func2(y, X, thresholds) # Standardized if statements
1 loops, best of 3: 1.5 s per loop
%timeit func3(y, X, thresholds) # Vectorized ~ 500x improvement
100 loops, best of 3: 2.74 ms per loop
非常感謝你。我從你的代碼中學到很多東西。 – PENG