2017-02-17 62 views
2

我的一位同事指出,當您需要屏蔽Keras中的非RNN輸入時,使用sample_weight代替屏蔽層的選項非常酷。Keras:爲非RNN屏蔽零填充輸入

就我而言,我在輸入中有62列,第63位是響應。前62列中97%以上的非零值包含在前30列中。我試圖讓這個工作,所以我想重量最後32列在訓練中0,本質上是創造一個'窮人的面具'。

這是一個8級分類任務,使用MLP。響應變量已使用Keras中的to_categorical()函數進行了轉換。

這裏的實現:

model = Sequential() 
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu')) 
model.add(Dense(8, init='uniform', activation='sigmoid')) 
hist = model.fit(X, y, 
       validation_data=(X_test, ytest), 
       nb_epoch=epochs_, 
       batch_size=batch_size_, 
       callbacks=callbacks_list, 
       sample_weight = np.array([X.shape[1]-32, 30])) 

我得到這個錯誤:

in standardize_weights 
assert y.shape[:sample_weight.ndim] == sample_weight.shape 

如何解決我的sample_weight爲 '面具' 輸入的前32列?

回答

2

樣品重量是不工作這樣的:

sample_weight : optional array of the same length as x , containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length) , to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile() . source

換句話說,該設置使上樣本訓練數據的,而不是在每個樣品的特徵的不同的權重。這僅用於訓練步驟。 我想你應該使用遮罩,如果你不希望圖層使用這些功能。或者只是從你的數據集中刪除它們?或者,如果它不太複雜,讓網絡自己學習哪些有用的功能。

這有幫助嗎?