我的一位同事指出,當您需要屏蔽Keras中的非RNN輸入時,使用sample_weight
代替屏蔽層的選項非常酷。Keras:爲非RNN屏蔽零填充輸入
就我而言,我在輸入中有62列,第63位是響應。前62列中97%以上的非零值包含在前30列中。我試圖讓這個工作,所以我想重量最後32列在訓練中0,本質上是創造一個'窮人的面具'。
這是一個8級分類任務,使用MLP。響應變量已使用Keras中的to_categorical()
函數進行了轉換。
這裏的實現:
model = Sequential()
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='sigmoid'))
hist = model.fit(X, y,
validation_data=(X_test, ytest),
nb_epoch=epochs_,
batch_size=batch_size_,
callbacks=callbacks_list,
sample_weight = np.array([X.shape[1]-32, 30]))
我得到這個錯誤:
in standardize_weights
assert y.shape[:sample_weight.ndim] == sample_weight.shape
如何解決我的sample_weight
爲 '面具' 輸入的前32列?