2017-06-19 109 views
2

我想使用來自Keras的預訓練Inception-V3模型,與來自Tensorflow的輸入管道(即通過張量輸入網絡輸入)配對。 這是我的代碼:Keras模型預測在使用張量輸入時發生變化

import tensorflow as tf 
from keras.preprocessing.image import load_img, img_to_array 
from keras.applications.inception_v3 import InceptionV3, decode_predictions, preprocess_input 
import numpy as np 

img_sample_filename = 'my_image.jpg' 
img = img_to_array(load_img(img_sample_filename, target_size=(299,299))) 
img = preprocess_input(img) 
img_tensor = tf.constant(img[None,:]) 

# WITH KERAS: 
model = InceptionV3() 
pred = model.predict(img[None,:]) 
pred = decode_predictions(np.asarray(pred)) #<------ correct prediction! 
print(pred) 

# WITH TF: 
model = InceptionV3(input_tensor=img_tensor) 
init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

pred = decode_predictions(np.asarray(pred)[0]) 
print(pred)        #<------ wrong prediction! 

其中my_image.jpg是我要分類的任何圖像。

如果我用keras'predict函數來計算預測,結果是正確的。但是,如果我將張量從圖像陣列中取出並通過input_tensor=...將張量輸入到模型,然後通過sess.run([model.output], ...)計算預測結果是非常錯誤的。

不同行爲的原因是什麼?我不能以這種方式使用Keras網絡嗎?

回答

1

最後,通過InceptionV3代碼挖,我發現這個問題:sess.run(init)覆蓋在InceptionV3的構造函數加載weigts。 我發現這個問題的-dirty-修復是在sess.run(init)之後重新加載權重。

from keras.applications.inception_v3 import get_file, WEIGHTS_PATH 

with tf.Session() as sess: 
    from keras import backend as K 
    K.set_session(sess) 

    sess.run(init) 
    weights_path = get_file(
       'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', 
       WEIGHTS_PATH, 
       cache_subdir='models', 
       md5_hash='9a0d58056eeedaa3f26cb7ebd46da564') 
    model.load_weights(weights_path) 
    pred = sess.run([model.output], feed_dict={K.learning_phase(): 0}) 

注意:爲get_file()的參數直接從InceptionV3的構造和拍攝,在我的例子,都僅限於與image_data_format='channels_last'還原整個網絡的權重。 我在this Github issue詢問是否有更好的解決方法。我會更新這個答案,如果我應該得到更多的信息。

+0

您可以始終初始化變量子集,而不是初始化每個變量(包括模型預先訓練的權重)。 – abhinavkulkarni

相關問題