2017-12-18

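Transfer learning with the tf.estimator.Estimator framework

I'm trying to do transfer learning on an Inception-resnet v2 model pretrained on ImageNet, using my own dataset and classes. My original codebase was a modification of a tf.slim sample that I can't find anymore, and now I'm trying to rewrite the same code using the tf.estimator.* framework.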

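However, I'm running into the problem of loading only some of the weights from the pretrained checkpoint while initializing the remaining layers with their default initializers.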

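Researching the problem, I found this GitHub issue and this question, both of which mention the need to use tf.train.init_from_checkpoint in my model_fn. I tried that, but since neither includes an example, I suspect I got something wrong.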

This is my minimal example:

import sys 
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 
import tensorflow as tf 
import numpy as np 

import inception_resnet_v2 

NUM_CLASSES = 900 
IMAGE_SIZE = 299 

def input_fn(mode, num_classes, batch_size=1): 
    # some code that loads images, reshapes them to 299x299x3 and batches them 
    return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES) 


def model_fn(images, labels, num_classes, mode):
    with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2.inception_resnet_v2(
            images,
            num_classes,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
    scopes = {os.path.dirname(v.name) for v in variables_to_restore}
    tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt',
                                  {s + '/': s + '/' for s in scopes})

    tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    total_loss = tf.losses.get_total_loss()  # obtain the regularization losses as well

    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=0.00002)
        train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op)

def main(unused_argv): 
    # Create the Estimator 
    classifier = tf.estimator.Estimator(
     model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode), 
     model_dir='model/MCVE') 

    # Train the model 
    classifier.train(
     input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, batch_size=1), 
     steps=1000) 

    # Evaluate the model and print results 
    eval_results = classifier.evaluate(
     input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, batch_size=1)) 
    print() 
    print('Evaluation results:\n %s' % eval_results) 

if __name__ == '__main__': 
    tf.app.run(main=main, argv=[sys.argv[0]]) 

where inception_resnet_v2 is the model implementation in Tensorflow's models repository.

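If I run this script, I get a bunch of info logs from init_from_checkpoint, but then, at session creation time, it seems to try to load the Logits weights from the checkpoint and fails because of incompatible shapes. This is the full traceback: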

Traceback (most recent call last): 

    File "<ipython-input-6-06fadd69ae8f>", line 1, in <module> 
    runfile('C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py', wdir='C:/Users/1/Desktop/transfer_learning_tutorial-master') 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile 
    execfile(filename, namespace) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile 
    exec(compile(f.read(), filename, 'exec'), namespace) 

    File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 77, in <module> 
    tf.app.run(main=main, argv=[sys.argv[0]]) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run 
    _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 

    File "C:/Users/1/Desktop/transfer_learning_tutorial-master/MCVE.py", line 68, in main 
    steps=1000) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train 
    loss = self._train_model(input_fn, hooks, saving_listeners) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 780, in _train_model 
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess: 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 368, in MonitoredTrainingSession 
    stop_grace_period_secs=stop_grace_period_secs) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 673, in __init__ 
    stop_grace_period_secs=stop_grace_period_secs) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 493, in __init__ 
    self._sess = _RecoverableSession(self._coordinated_creator) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 851, in __init__ 
    _WrappedSession.__init__(self, self._create_session()) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 856, in _create_session 
    return self._sess_creator.create_session() 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 554, in create_session 
    self.tf_sess = self._session_creator.create_session() 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\monitored_session.py", line 428, in create_session 
    init_fn=self._scaffold.init_fn) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\session_manager.py", line 279, in prepare_session 
    sess.run(init_op, feed_dict=init_feed_dict) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 889, in run 
    run_metadata_ptr) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run 
    feed_dict_tensor, options, run_metadata) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run 
    options, run_metadata) 

    File "C:\Users\1\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call 
    raise type(e)(node_def, op, message) 

InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [900] rhs shape= [1001] [[Node: Assign_1145 = Assign[T=DT_FLOAT, 
_class=["loc:@InceptionResnetV2/Logits/Logits/biases"], use_locking=true, validate_shape=true, 
_device="/job:localhost/replica:0/task:0/device:CPU:0"](InceptionResnetV2/Logits/Logits/biases, checkpoint_initializer_1145)]] 

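What am I doing wrong when using init_from_checkpoint? How exactly are we supposed to "use" it in our model_fn? And why does the estimator try to load the Logits weights from the checkpoint when I'm explicitly telling it not to?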

Update:

Following the suggestion in the comments, I tried other ways of calling tf.train.init_from_checkpoint.

Using {v.name: v.name}

If, as the comment suggests, I replace the mapping in the call with {v.name:v.name for v in variables_to_restore}, I get this error:

ValueError: Assignment map with scope only name InceptionResnetV2/Conv2d_2a_3x3 should map 
to scope only InceptionResnetV2/Conv2d_2a_3x3/weights:0. Should be 'scope/': 'other_scope/'. 

Using {v.name: v}

If, instead, I try a name:variable mapping, I get the following error:

ValueError: Tensor InceptionResnetV2/Conv2d_2a_3x3/weights:0 is not found in 
inception_resnet_v2_2016_08_30.ckpt checkpoint 
{'InceptionResnetV2/Repeat_2/block8_4/Branch_1/Conv2d_0c_3x1/BatchNorm/moving_mean': [256], 
'InceptionResnetV2/Repeat/block35_9/Branch_0/Conv2d_1x1/BatchNorm/beta': [32], ... 

The error goes on to list what I believe are all the variable names in the checkpoint (or could they be scopes?).

Update (2)

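After inspecting the latest error above, I see that InceptionResnetV2/Conv2d_2a_3x3/weights is in the list of checkpoint variables; the problem is the :0 at the end. I'm now going to verify whether stripping it really solves the problem, and if so I'll post an answer.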
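The mismatch described in Update (2) can be illustrated in plain Python, without TensorFlow. This is a minimal sketch: `checkpoint_key` is a hypothetical helper, and `ckpt_keys` is a hand-written stand-in for the names the checkpoint actually lists. A graph variable's `name` carries the tensor output index `:0`, while the checkpoint keys do not, so the raw name is never found.

```python
def checkpoint_key(variable_name):
    """Strip the ':<index>' tensor suffix from a graph variable name."""
    return variable_name.split(':')[0]


# Hand-written stand-in for the checkpoint's variable names (no ':0' suffix).
ckpt_keys = {'InceptionResnetV2/Conv2d_2a_3x3/weights'}

graph_name = 'InceptionResnetV2/Conv2d_2a_3x3/weights:0'
print(graph_name in ckpt_keys)                  # False: ':0' is not in the checkpoint
print(checkpoint_key(graph_name) in ckpt_keys)  # True once the index is stripped
```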


Is there any checkpoint in the estimator's directory 'model/MCVE'? –


No, the directory is empty. – GPhilo


Maybe 'scopes = {os.path.dirname(v.name) for v in variables_to_restore}' is adding InceptionResnetV2 to the list of scopes, so all the variables under InceptionResnetV2/ get loaded. You could try listing the variables directly instead of building a list of scopes: 'tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', {v.name:v.name for v in variables})' –

Answer


Thanks to @KathyWu's comment, I got on the right track and found the problem.

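Indeed, the way I was computing scopes would include the InceptionResnetV2/ scope, which triggers loading every variable "under" that scope (i.e., all the variables in the network). Replacing it with the correct dictionary, however, was not trivial.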
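To see how a dirname-based scope set can turn out too broad, here is a minimal pure-Python sketch. It assumes, as a simplification of how init_from_checkpoint treats 'scope/' entries, that such an entry covers every variable whose name starts with that prefix. `scopes_covering_excluded` and the variable `InceptionResnetV2/some_top_level_var` are hypothetical; only the Logits bias name comes from the traceback above.

```python
import os


def scopes_covering_excluded(variable_names, excluded_names):
    """Return the scopes (built with os.path.dirname, as in the question)
    whose 'scope/' prefix would also match an excluded variable.

    Simplified model (assumption): a 'scope/' entry in the assignment map
    covers every variable whose name starts with that prefix.
    """
    scopes = {os.path.dirname(name) for name in variable_names}
    bad = set()
    for scope in scopes:
        prefix = scope + '/'
        if any(ex.startswith(prefix) for ex in excluded_names):
            bad.add(scope)
    return sorted(bad)


# Hypothetical: one variable lives directly under InceptionResnetV2/,
# so its dirname is the top-level scope itself.
to_restore = [
    'InceptionResnetV2/Conv2d_1a_3x3/weights:0',
    'InceptionResnetV2/some_top_level_var:0',
]
excluded = ['InceptionResnetV2/Logits/Logits/biases:0']
print(scopes_covering_excluded(to_restore, excluded))  # ['InceptionResnetV2']
```

Under this model, a single top-level entry is enough to pull the excluded Logits variables back into the restore set, which is consistent with the shape error in the question.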

Of the possible scope patterns init_from_checkpoint accepts, the one I had to use was 'scope_variable_name': variable, but without using the actual variable.name attribute.

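variable.name looks like 'some_scope/variable_name:0'. The :0 is not part of the variable names stored in the checkpoint, so using v.name as the checkpoint key (for example {v.name: v for v in variables_to_restore}) raises a "not found" error.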

The trick to make it work is to strip the tensor index from the name:

tf.train.init_from_checkpoint('inception_resnet_v2_2016_08_30.ckpt', 
           {v.name.split(':')[0]: v for v in variables_to_restore})
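Because the original failure only surfaced at session creation, it can help to validate the assignment map up front. The sketch below works on plain (name, shape) data so it runs without TensorFlow. `find_shape_mismatches` is a hypothetical helper; feeding it from something like `tf.train.list_variables(ckpt_path)` for the checkpoint side and each variable's `name` and `shape` for the graph side is an assumption, not something the answer does. The [900] and [1001] shapes are the ones from the traceback; the Conv2d_2a_3x3 shape is illustrative.

```python
def find_shape_mismatches(graph_vars, ckpt_vars):
    """Compare graph variables against checkpoint variables before restoring.

    graph_vars: dict of graph variable name (with ':0' suffix) -> shape list
    ckpt_vars:  dict of checkpoint variable name (no suffix) -> shape list
    Returns (missing, mismatched): stripped names absent from the checkpoint,
    and (name, graph_shape, ckpt_shape) tuples whose shapes differ.
    """
    missing, mismatched = [], []
    for name, shape in sorted(graph_vars.items()):
        key = name.split(':')[0]  # same stripping as in the answer
        if key not in ckpt_vars:
            missing.append(key)
        elif list(shape) != list(ckpt_vars[key]):
            mismatched.append((key, list(shape), list(ckpt_vars[key])))
    return missing, mismatched


graph_vars = {
    'InceptionResnetV2/Conv2d_2a_3x3/weights:0': [3, 3, 32, 32],
    'InceptionResnetV2/Logits/Logits/biases:0': [900],
}
ckpt_vars = {
    'InceptionResnetV2/Conv2d_2a_3x3/weights': [3, 3, 32, 32],
    'InceptionResnetV2/Logits/Logits/biases': [1001],
}
print(find_shape_mismatches(graph_vars, ckpt_vars))
# ([], [('InceptionResnetV2/Logits/Logits/biases', [900], [1001])])
```

Any name reported as mismatched (here the 900-class Logits bias) belongs in the `exclude` list rather than in the assignment map.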