2017-05-30 168 views
0

我有一個很深的CNN,適用於多類分類。我想「升級」挑戰並在多標籤分類問題上進行培訓。多標籤分類:如何學習閾值?

要做到這一點我代替我的SOFTMAX通過乙狀結腸,試圖培養我的網絡,以儘量減少:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred) 

但我結束了怪異的預測:

Prediction for Im1 : [ 0.59275776 0.08751075 0.37567005 0.1636796 0.42361438 0.08701646 

0.38991812 0.54468459 0.34593087 0.82790571]

Prediction for Im1 : [ 0.52609032 0.07885984 0.45780018 0.04995904 0.32828355 0.07349177 

0.35400775 0.36479294 0.30002621 0.84438241]

Prediction for Im1 : [ 0.58714485 0.03258472 0.3349618 0.03199361 0.54665488 0.02271551 

0.43719986 0.54638696 0.20344526 0.88144571]

所以我想我儘量讓我的網絡學習門檻爲每個類,以確定樣本屬於或不給TE類。

所以我將此添加到我的代碼:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1) 
W_thresh = tf.Variable(initial) 

y_predict_thresh = int(y_predict > W_thresh) 

但是我有一個錯誤:

TypeError: int() argument must be a string or a number, not 'Tensor'. 

任何人有任何想法,以幫助我前進(如何避免這種錯誤,是我的數據集實際上是不平衡的這一事實導致了這種「不變的」預測?對於多標籤分類的其他建議?,...)?

謝謝

編輯:我只是意識到,這樣做的閾值可能不是反傳真的很酷:/

回答

2

不知道如果你仍然需要它,但你可以使用tensorflow轉換功能tf.to_int32, tf.to_int64。在評估之前,你的表達式是python的一個對象,因此它不能簡單地將它轉換爲int()

這確實你需要:

with tf.Session() as sess: 
    check = sess.run([tf.to_int64(W1 > W2)]) 
+0

謝謝你來回ansewer。我實際上意識到我的問題的解決方案不是門檻,而是解決了一個事實,即我有一個不平衡的數據集。但是你的回答對我的「門檻」問題很好。再次感謝你:) –

+0

不客氣! ;-) – gionni