2017-10-06 139 views
0

我在添加對binary_crossentropy的懲罰時遇到了問題。當預定義的錯誤組的平均值違反某個閾值時,這個想法是懲罰損失函數。 以下是幫助函數,它用掩碼錶示組和已計算的crossentropy。它會簡單地返回違反某個閾值的次數來懲罰調用它的實際損失函數。Keras中的自定義丟失函數的問題

def penalty(groups_mask, binary_crossentropy): 
    errors = binary_crossentropy 
    unique_groups = set(groups_mask) 
    groups_mask = np.array(groups_mask) 
    threshold = # whatever 
    c = 0 
    for group in unique_groups: 
     error_mean = K.mean(errors[(groups_mask == group).nonzero()], axis=-1) 
     if error_mean > threshold: 
     c += 1 
    return c 

麻煩的是,error_mean不是標量,我找不出一個簡單的方法來比較它的閾值。

+0

我真的'不明白你想在這一行中實現什麼:'error_mean = K.mean(errors [(groups_mask == group).nonzero()],axis = -1)' –

回答

2

必須使用張量和功能從keras backend

import keras.backend as K 

在錯誤的線所做的一切,你必須比較使用了這些功能太事情:

.... 
c = K.variable([0]) 
..... 
..... 
    errorGreater = K.cast(K.greater(error_mean,threshold), K.floatx()) 
    c+=K.max(errorGreater) #if error_mean is 1 element only, you can just c+=errorGreater. 
+0

我認爲theano允許你不使用'cast '。 –

相關問題