2017-07-19 142 views
0

我有標籤的NumPy的數組:IndexError索引的二維數組與一維數組(NumPy的)

labels = np.ndarray(10000, dtype=np.float32) 

在數組中的元素看起來像:

print(labels[1:5]) 
Output: [ 9. 9. 4. 1.] 

我想將它們轉換成一個熱編碼的標籤,我用下面的代碼:

one_hot_labels = np.eye(10)[labels] 

我得到以下錯誤:

IndexError  Traceback (most recent call last) 
<ipython-input-21-dccf85afc031> in <module>() 
    1 
----> 2 s=np.eye(10)[labels] 

IndexError: arrays used as indices must be of integer (or boolean) type 

我該如何解決這個問題?

+0

你確定標籤和火車標籤是一樣的嗎? –

+2

你需要使用整數值作爲索引:'one_hot_labels = np.eye(10)[labels.astype(int)]' – JohanL

+0

@JohanL謝謝。它的工作原理 – Jayanth

回答

2

您已將標籤定義爲np.float32。如果要將它們用作數組或矩陣的索引,則它們必須是整數。要轉換np.float32使用.astype(int)

one_hot_labels=np.eye(10)[labels.astype(int)] 

或整數直接定義標籤:

labels=np.ndarray(10000,dtype=int) 
+1

@Jayanth如果他回答了你的問題,請接受答案。 :) – SH7890

1

如果labelsfloat,你不希望改變其dtype,你可以簡單地使用MultiLabelBinarizer。這段代碼應該完成這項工作:

from sklearn.preprocessing import MultiLabelBinarizer 

mlb = MultiLabelBinarizer() 
one_hot_labels = mlb.fit_transform(labels[:, None])