2017-06-14 67 views
1

從我使用keras的fit_generator函數中的自定義生成函數返回。與keras`fit_generator()外形問題`我有與numpy的陣列的形狀,一個看似簡單的問題

發電機功能類似於此:

def data_generator(full_data, encoder): 
    for s in full_data: 
     in1_X = encoder.encode(s[:,0]) 
     in2_X = encoder.encode(s[:,1]) 
     out1_y = encoder.encode(s[:,2]) 
     out2_y = encoder.encode(s[:,3]) 
     X = [in1_X, in2_X] 
     y = [out1_y, out2_y] 
     yield (X,y) 

我可以通過使用for循環和打印的形狀,這只是返回(60,)

然而,在調用這個時候得到in1_X返回的形狀使用fit_generator()功能,它失敗:

train_data_gen = data_generator(full_data, encoder) 


main_in = Input(shape=(seq_len,), name='main_input') 

# ... 
# define model 
# ... 

joint_model.fit_generator(train_data_gen, steps_per_epoch=2000, epochs=2) 

從其中輸出是這樣的:

Error when checking input: 
expected main_input to have shape (None, 60) but got array with shape (60, 1) 

我怎樣才能得到這個不改變形狀(60,) numpy陣列形狀(60, 1)?其他人遇到過這個問題嗎?

+0

昂貴的管道膠帶是變平。 in1_x = in1_x.flatten()等等,因爲你的編碼器似乎返回一個二維數組。 – Uvar

+0

我們正面臨同樣的問題。扁平化並沒有什麼區別 - 它提供了與未扁平化相同的形狀。根據我們的經驗,我們調用發電機和評估形狀,並收到預期的一個:X,Y = my_generator.next() 打印( 「下一個[X]:{}」。格式(x.shape))#形狀(517,)在我們的例子中,但是當餵食時,我們得到錯誤「期望的main_input有形狀(無,517),但有形狀的數組(517,1)」 – kekec

回答