2017-10-20

Keras layer.set_weights does not modify the layer. Why?

When I initialize a model and load its weights, its predictions are 67% accurate.

# Load the fine-tuned weights and predict on the unshuffled batches.
model.load_weights(path + 'results/finetune_train_last_layer.h5')
batches = model.get_batches(path, shuffle=False, batch_size=128, class_mode=None)
preds = model.predict_generator(batches, batches.nb_sample)

# Count the predictions whose most probable class matches the true label.
matches = 0
for guess, ans in zip(np.argmax(preds, axis=1), batches.classes):
    if guess == ans:
        matches += 1
print('%s/%s' % (matches, len(batches.classes)))

532/792 

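The weights load correctly: this is the same accuracy I got in the last epoch of training, just before I saved them.

However, when I build a new model with the same layers as the last layers of model and copy the weights across, the check says the weights are different. How is this possible?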

no_drop_model = Sequential([
    MaxPooling2D(input_shape=(512, 14, 14)),
    Flatten(),
    Dense(4096, activation='relu'),
    Dropout(0.),
    Dense(4096, activation='relu'),
    Dropout(0.),
    Dense(120, activation='softmax')
])
for ndl, fcl in zip(no_drop_model.layers, model.layers[31:]):
    print(type(ndl), type(fcl))
    ndl.set_weights(fcl.get_weights())
    if ndl.get_weights():
        print(np.array_equiv(ndl.get_weights(), fcl.get_weights()))

Output:

(<class 'keras.layers.pooling.MaxPooling2D'>, <class 'keras.layers.pooling.MaxPooling2D'>) 
(<class 'keras.layers.core.Flatten'>, <class 'keras.layers.core.Flatten'>) 
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>) 
False 
(<class 'keras.layers.core.Dropout'>, <class 'keras.layers.core.Dropout'>) 
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>) 
False 
(<class 'keras.layers.core.Dropout'>, <class 'keras.layers.core.Dropout'>) 
(<class 'keras.layers.core.Dense'>, <class 'keras.layers.core.Dense'>) 
False 

Answer

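The return value of model.get_weights() (and likewise of layer.get_weights()) is a list of NumPy arrays, not a single array. Compare the weights array by array, like this: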

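import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

def create_model():
    i = Input((2,))
    o = Dense(3)(i)
    return Model(i, o)

model1 = create_model()
model2 = create_model()

# Compare array by array: the kernel first, then the bias.
for w1, w2 in zip(model1.get_weights(), model2.get_weights()):
    print(np.array_equiv(w1, w2))

# Copy the weights over and compare again.
model2.set_weights(model1.get_weights())
print('Weights after:')
for w1, w2 in zip(model1.get_weights(), model2.get_weights()):
    print(np.array_equiv(w1, w2))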

This produces the following output:

False 
True 
Weights after: 
True 
True 

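The second element of the weight list holds the bias values, which are initialized to zero, so those values were already identical before the weights were copied.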
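To apply the same idea layer by layer, the pairwise comparison can be wrapped in a small helper. The sketch below is only an illustration: weights_equal is a hypothetical name, not a Keras function, and the two lists are NumPy stand-ins for what get_weights() would return for a Dense(3) layer on 2 inputs (a random kernel and a zero bias). It uses np.array_equal, which also requires the shapes to match, in place of np.array_equiv.

```python
import numpy as np

def weights_equal(weights_a, weights_b):
    """Return True if two lists of arrays match pairwise in shape and value.

    Hypothetical helper for illustration; not part of Keras.
    """
    if len(weights_a) != len(weights_b):
        return False
    return all(np.array_equal(a, b) for a, b in zip(weights_a, weights_b))

# Stand-ins for get_weights() output: a (2, 3) kernel with random values
# and a (3,) bias initialized to zeros.
rng = np.random.RandomState(0)
layer1_weights = [rng.rand(2, 3), np.zeros(3)]
layer2_weights = [rng.rand(2, 3), np.zeros(3)]

# Before copying, the kernels differ but the zero biases already match.
print(weights_equal(layer1_weights, layer2_weights))
print(np.array_equal(layer1_weights[1], layer2_weights[1]))

# Simulate set_weights by copying every array, then compare again.
layer2_weights = [w.copy() for w in layer1_weights]
print(weights_equal(layer1_weights, layer2_weights))
```

In the loop from the question, the check would then presumably read weights_equal(ndl.get_weights(), fcl.get_weights()) instead of passing both lists straight to np.array_equiv.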

Nice! It is cleaner this way, too. Thanks for the response.

Please accept the answer.
