0
我在添加對binary_crossentropy的懲罰時遇到了問題。當預定義的錯誤組的平均值違反某個閾值時,這個想法是懲罰損失函數。 以下是幫助函數,它用掩碼錶示組和已計算的crossentropy。它會簡單地返回違反某個閾值的次數來懲罰調用它的實際損失函數。Keras中的自定義丟失函數的問題
def penalty(groups_mask, binary_crossentropy):
errors = binary_crossentropy
unique_groups = set(groups_mask)
groups_mask = np.array(groups_mask)
threshold = # whatever
c = 0
for group in unique_groups:
error_mean = K.mean(errors[(groups_mask == group).nonzero()], axis=-1)
if error_mean > threshold:
c += 1
return c
麻煩的是,error_mean不是標量,我找不出一個簡單的方法來比較它的閾值。
我真的'不明白你想在這一行中實現什麼:'error_mean = K.mean(errors [(groups_mask == group).nonzero()],axis = -1)' –