2017-04-09

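Trying to add a Flatten layer on top of ResNet50 (notop) and getting an error

I am trying to add a Flatten layer, a Dense layer (relu) and a Dense layer (softmax) on top of ResNet50 for a multi-class classification task, using Keras 2.0.2 and Theano 0.9.0 with Python 2.7 on Win10. Here is my code: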

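from keras.applications.resnet50 import ResNet50
from keras.layers import Flatten, Dense, Dropout
from keras.models import Model
from keras import regularizers

def create_model():
    # ResNet50 without its classifier head; weights are loaded manually below
    base_model = ResNet50(include_top=False, weights=None,
                          input_tensor=None, input_shape=(3, 224, 224),
                          pooling=None)

    # weight_path: directory holding the downloaded weight file (defined elsewhere)
    base_model.load_weights(weight_path + '/resnet50_weights_th_dim_ordering_th_kernels_notop.h5')

    # New head: Flatten -> Dense(relu) -> Dropout -> Dense(softmax, 8 classes)
    x = base_model.output
    x = Flatten()(x)
    x = Dense(128, activation='relu', kernel_initializer='random_normal',
              kernel_regularizer=regularizers.l2(0.1),
              activity_regularizer=regularizers.l2(0.1))(x)
    x = Dropout(0.3)(x)
    y = Dense(8, activation='softmax')(x)
    model = Model(base_model.input, y)

    # Freeze the pretrained ResNet50 layers before compiling
    for layer in base_model.layers:
        layer.trainable = False
    model.compile(optimizer='adadelta',
                  loss='categorical_crossentropy')
    return model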

I have already set the image dimension ordering:

from keras import backend as K 
K.set_image_dim_ordering('th') 

Here is my keras.json file:

{
    "backend": "theano",
    "image_data_format": "channels_first",
    "epsilon": 1e-07,
    "floatx": "float32"
}
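Because this keras.json sets "image_data_format": "channels_first", the input_shape passed to ResNet50 has to list the channel axis first, as (3, 224, 224) does. The following is a minimal, stdlib-only sketch of that relationship; the helper names are hypothetical (not Keras API), it assumes the file sits at the default ~/.keras/keras.json, and it falls back to "channels_last" (Keras' default) when the key is missing.

```python
import json
import os


def load_keras_config(path=os.path.expanduser("~/.keras/keras.json")):
    """Hypothetical helper: read keras.json from its (assumed) default location."""
    with open(path) as f:
        return json.load(f)


def default_input_shape(config, size=224, channels=3):
    """Hypothetical helper: per-sample image shape implied by a keras.json dict.

    The batch dimension is not included, matching Keras' input_shape convention.
    """
    fmt = config.get("image_data_format", "channels_last")
    if fmt == "channels_first":
        return (channels, size, size)   # e.g. (3, 224, 224), Theano-style 'th'
    if fmt == "channels_last":
        return (size, size, channels)   # e.g. (224, 224, 3), TensorFlow-style 'tf'
    raise ValueError("unknown image_data_format: %r" % (fmt,))
```

For the config above, default_input_shape returns (3, 224, 224), which matches the input_shape used in create_model().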

Here is the error message:

ValueError: The shape of the input to "Flatten" is not fully defined (got (2048, None, None). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.

What is the error stack trace? – putonspectacles

I should probably mention that if I leave out this line, everything works fine: base_model.load_weights(weight_path+'/resnet50_weights_th_dim_ordering_th_kernels_notop. – JumpyWarlock

Answers

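You should pass an input_shape argument to the first layer. This is a shape tuple (a tuple of integers or None entries, where None means any positive integer may be expected). The batch dimension is not included in input_shape.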
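To see why the None entries matter: Flatten has to know the product of all non-batch dimensions to report its output size, so an unknown (None) spatial size fails. Below is a plain-Python sketch of that check (not the Keras source); as far as I can tell Keras 2.0's Flatten uses all() on the non-batch dimensions, which would mean a dimension of 0 fails the same way as None. The function name is hypothetical.

```python
import operator
from functools import reduce


def flatten_output_shape(input_shape):
    """Sketch of Flatten's shape check.

    input_shape includes the batch dimension first, e.g. (None, 2048, 1, 1).
    """
    per_sample = input_shape[1:]
    # all() is False if any entry is None (unknown) or 0 (collapsed feature map)
    if not all(per_sample):
        raise ValueError(
            'The shape of the input to "Flatten" is not fully defined '
            "(got %s)." % (per_sample,)
        )
    # Flattened size is the product of the per-sample dimensions
    return (input_shape[0], reduce(operator.mul, per_sample, 1))
```

With the shape from the question, flatten_output_shape((None, 2048, None, None)) raises the same kind of ValueError, while a fully known shape such as (None, 2048, 1, 1) flattens to (None, 2048).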

In your case, the first layer is the Flatten() layer. It should look something like this:

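from keras.layers import Input, Flatten
# output_shape_of_resnet: ResNet50 output shape without the batch dimension
your_input = Input(shape=output_shape_of_resnet)
x = Flatten()(your_input)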

As for feeding the output of ResNet50 into your own layers, you could define a new model that combines your own layers with ResNet, like this:

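from keras.models import Sequential

# resnet_model and your_own_layers_model stand for models you have already built
new_model = Sequential()
new_model.add(resnet_model)  # Of course you need the definition and weights of resnet
resnet_model.trainable = False  # I guess? Set this before compiling new_model
new_model.add(your_own_layers_model)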

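In my case, this error occurred when the input images were too small for the network: if a layer's output size shrinks to 0, you get this error. You can use model.summary() to see what your network looks like. Here is an example of model.summary() output: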

Layer (type)     Output Shape    Param # 
================================================================= 
conv2d_78 (Conv2D)   (None, 16, 21, 21)  160  
_________________________________________________________________ 
max_pooling2d_62 (MaxPooling (None, 16, 5, 5)   0   
_________________________________________________________________ 
... 
flatten_25 (Flatten)   (None, 32)    0   
_________________________________________________________________ 
dense_28 (Dense)    (None, 2)     1026  
================================================================= 
Total params: 31,970 
Trainable params: 31,970 
Non-trainable params: 0 
_________________________________________________________________ 
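To check ahead of time whether an input is too small, you can trace the spatial size through each convolution/pooling layer by hand. The sketch below uses the usual output-length formula for "valid" and "same" padding (dilation 1), clamped at 0; the layer list is a hypothetical example, not the network from the summary above, and the helper names are not Keras API.

```python
def conv_output_length(input_length, filter_size, stride=1, padding="valid"):
    """Output length of a conv/pool layer along one spatial axis (dilation 1)."""
    if padding == "valid":
        length = input_length - filter_size + 1
    elif padding == "same":
        length = input_length
    else:
        raise ValueError("unknown padding: %r" % (padding,))
    # ceil(length / stride), clamped at 0 once the feature map has vanished
    return max(0, (length + stride - 1) // stride)


def trace_spatial_size(input_length, layers):
    """Return the spatial size before and after each (filter, stride, padding) layer."""
    sizes = [input_length]
    for filter_size, stride, padding in layers:
        sizes.append(conv_output_length(sizes[-1], filter_size, stride, padding))
    return sizes


# Hypothetical stack: conv 3x3, pool 2x2, conv 3x3, pool 2x2, conv 3x3
LAYERS = [(3, 1, "valid"), (2, 2, "valid"), (3, 1, "valid"),
          (2, 2, "valid"), (3, 1, "valid")]

print(trace_spatial_size(28, LAYERS))  # [28, 26, 13, 11, 5, 3]
print(trace_spatial_size(12, LAYERS))  # [12, 10, 5, 3, 1, 0]
```

With a 12-pixel input the last feature map collapses to 0, which is the situation where a following Flatten layer fails.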