2017-05-04 96 views
0

我正在努力計算一個簡單的二元分類問題中的類標籤,由2d-numpy數組給出每個類的概率。二進制分類中的概率閾值

例如:

prob_01 = 
array([[ 0.49253953, 0.50746047], 
     [ 0.01041495, 0.98958505], 
     [ 0.76774408, 0.23225592], 
     ..., 
     [ 0.79755047, 0.20244953], 
     [ 0.27228677, 0.72771323], 
     [ 0.26953926, 0.73046074]]) 

其中行是實例和列包含分別在類0和1是用於每個實例概率。例如,對於一個threshold = 0.5,一個應該得到:

labels_01= 
array([[ 1], 
     [ 1], 
     [ 0], 
     ..., 
     [ 1], 
     [ 0], 
     [ 0]]) 

什麼是產生labels_01陣列的最簡單和Python的方式?

回答

1

爲0級(第一列):

threshold = 0.5 
labels_01 = prob_01[:,0] < threshold 

真正得到整數代替布爾(假定import numpy as np):

labels_01 = (prob_01[:,0] < threshold).astype(np.int) 

或者只是使用

prob_01 < threshold 

到一次獲得兩列並稍後索引一列。

+0

哇,太簡單了!我完全錯過了只有一列。謝謝。 –