Tensorflow: how to restore a model when adding some new layers?

I have trained a model and saved checkpoints. The code of my model is:

with tf.variable_scope(scope): 
    self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32) 
    self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32, 
           kernel_size=[8, 8], stride=4, padding='SAME') 
    self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64, 
           kernel_size=[4, 4], stride=2, padding='SAME') 
    self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64, 
           kernel_size=[3, 3], stride=1, padding='SAME') 
    self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu) 

    # Output layers for policy and value estimations 
    self.policy = slim.fully_connected(self.fc, 
             cfg.ACTION_DIM, 
             activation_fn=tf.nn.softmax, 
             biases_initializer=None) 
    self.value = slim.fully_connected(self.fc, 
             1, 
             activation_fn=None, 
             biases_initializer=None) 

There are about 32 processes running at the same time, and each of them has a copy of the global network defined in the code above; scope is the ID of each process. The scope of the global network is 'global'.
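(For context, the per-worker copies are presumably created by instantiating this network once per scope. The ACNetwork wrapper name below is an assumption for illustration only, not from the original code; it simply shows why checkpoint keys end up looking like worker_15/fully_connected_3/weights.)

import tensorflow as tf

# Hypothetical wrapper around the layer definitions shown above.
class ACNetwork(object):
    def __init__(self, scope):
        with tf.variable_scope(scope):
            self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32)
            # ... conv_1, conv_2, conv_3, fc, policy and value layers as in the snippet above ...

# One copy under the scope 'global', plus one copy per worker process.
master_network = ACNetwork('global')
worker_networks = [ACNetwork('worker_%d' % i) for i in range(32)]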

After that, I want to add more layers after the self.fc layer:

with tf.variable_scope(scope): 
    self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32) 
    self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32, 
           kernel_size=[8, 8], stride=4, padding='SAME') 
    self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64, 
           kernel_size=[4, 4], stride=2, padding='SAME') 
    self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64, 
           kernel_size=[3, 3], stride=1, padding='SAME') 
    self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu) 

    # Output layers for policy and value estimations 
    self.policy = slim.fully_connected(self.fc, 
             cfg.ACTION_DIM, 
             activation_fn=tf.nn.softmax, 
             biases_initializer=None) 
    self.value = slim.fully_connected(self.fc, 
             1, 
             activation_fn=None, 
             biases_initializer=None) 

    self.new_fc_1 = slim.fully_connected(self.fc, 512, activation_fn=tf.nn.elu) 

However, when I restore the model, it reports the following errors:

2017-08-03 22:23:43.473157: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
2017-08-03 22:23:43.477197: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 379803423 vs. calculated on the restored bytes 2648422677 
2017-08-03 22:23:43.477210: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3963326522 vs. calculated on the restored bytes 3154501583 
2017-08-03 22:23:43.477200: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3893236466 vs. calculated on the restored bytes 1767411214 
2017-08-03 22:23:43.478276: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 4239176201 vs. calculated on the restored bytes 3213118706 
2017-08-03 22:23:43.480438: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 442335910 vs. calculated on the restored bytes 4248164641 
2017-08-03 22:23:43.483885: W tensorflow/core/framework/op_kernel.cc:1158] Data loss: Checksum does not match: stored 3105262865 vs. calculated on the restored bytes 2648422677 
2017-08-03 22:23:43.483953: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
2017-08-03 22:23:43.486987: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
2017-08-03 22:23:43.490616: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
2017-08-03 22:23:43.491951: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
2017-08-03 22:23:43.491957: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
2017-08-03 22:23:43.494310: W tensorflow/core/framework/op_kernel.cc:1158] Not found: Key worker_15/fully_connected_3/weights not found in checkpoint 
    [[Node: save/RestoreV2_128 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_128/tensor_names, save/RestoreV2_128/shape_and_slices)]] 
.... .... 

I save the model with the following code:

saver.save(sess, self.model_path+'/model-'+str(episode_count)+'.ckpt') 

Here is the code that defines the saver:

value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope') 
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')) 
saver = tf.train.Saver(value_list, max_to_keep=100) 

with tf.Session(config=tf_configs) as sess: 
    coord = tf.train.Coordinator() 
    if load_model: 
     print('Loading Model...') 
     ckpt = tf.train.get_checkpoint_state(model_path) 
     saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
     sess.run(tf.global_variables_initializer()) 

How can I restore the pre-trained model when new layers with randomly initialized parameters are added to the current neural network?

+0

Restore the checkpoint with the old model and add the new tensors afterwards – aseipel
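(A minimal sketch of that suggestion, assuming the added layers live under their own variable scope, illustratively global/added_layer, while the old variables keep their original names, and with model_path as in the question:)

import tensorflow as tf

# Variables that exist in the checkpoint (the old part of the graph).
old_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope')
old_vars += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')
# Variables of the newly added layers, which the checkpoint does not contain.
new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/added_layer')

restore_saver = tf.train.Saver(old_vars)            # restores only keys present in the checkpoint
init_new_vars = tf.variables_initializer(new_vars)  # random-initializes only the added layers

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(model_path)
    restore_saver.restore(sess, ckpt.model_checkpoint_path)
    sess.run(init_new_vars)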

Answers

0

After googling for half a day, and with the help of @BlueSun, I found that the following method solves the problem.

First, before adding the new scope, save the model using only the variables in the current scopes:

value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope') 
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')) 
saver = tf.train.Saver(value_list, max_to_keep=100) 

and train the network.

Later, add the new scope and run the model like this:

value_list = [] 
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope')) 
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')) 
value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/added_layer')) 
saver = tf.train.Saver(value_list, max_to_keep=100) 

with tf.Session(config=tf_configs) as sess: 
    coord = tf.train.Coordinator() 
    if load_model: 
     print('Loading Model...') 
     ckpt = tf.train.get_checkpoint_state(model_path) 
     saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
     sess.run(tf.global_variables_initializer()) 

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="global"), max_to_keep=100) 

The last line above defines a new saver over the whole 'global' scope, so the checkpoints written from then on also contain the added layer. The network code looks like this:

with tf.variable_scope(scope): 
    with tf.variable_scope('old_scope'): 
     self.inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32) 
     self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32, 
            kernel_size=[8, 8], stride=4, padding='SAME') 
     self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64, 
            kernel_size=[4, 4], stride=2, padding='SAME') 
     self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64, 
            kernel_size=[3, 3], stride=1, padding='SAME') 
     self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512, activation_fn=tf.nn.elu) 

    with tf.variable_scope('added_layer'): 
     self.fc_1 = slim.fully_connected(self.fc, 512, activation_fn=tf.nn.elu) 

    with tf.variable_scope('actor_critic'): 
     # Output layers for policy and value estimations 
     self.policy = slim.fully_connected(self.fc_1, 
             cfg.ACTION_DIM, 
             activation_fn=tf.nn.softmax, 
             biases_initializer=None) 
     self.value = slim.fully_connected(self.fc_1, 
              1, 
              activation_fn=None, 
              biases_initializer=None) 

It works well now, although the code looks a bit inelegant.

1

You can use two separate variable scopes: one for saving and loading, and another one for the new layers.

Then you can tell the saver to work only with the variables from the first scope:

saver = tf.train.Saver(
    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="save_scope") 
) 
+0

Could you provide more examples? –

+0

`value_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/old_scope') value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/actor_critic')) value_list.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/added_layer')) saver = tf.train.Saver(value_list, max_to_keep=100)` I added a new layer, and afterwards, when I restore the model, it reports **Key global/added_layer/fully_connected/weights not found in checkpoint** –

+0

Is the 'global/added_layer' scope for the new layer? If it is, why do you extend your value_list with the added_layer scope? You have to initialize the saver only with the scopes that contain the old layers you are trying to load, like `saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="old_scope"), max_to_keep=100)`. Otherwise the saver will try to load layers that did not exist before. – BlueSun
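(For reference, a minimal self-contained sketch of the two-scope pattern described in this answer and the comments. The scope names old_scope and added_layer, the layer sizes, and model_path are illustrative only:)

import tensorflow as tf
import tensorflow.contrib.slim as slim

model_path = './model'  # illustrative checkpoint directory

# Old part of the network: its variables exist in the checkpoint.
with tf.variable_scope('old_scope'):
    inputs = tf.placeholder(shape=[None, 80, 80, 1], dtype=tf.float32)
    fc = slim.fully_connected(slim.flatten(inputs), 512, activation_fn=tf.nn.elu)

# New part of the network: not in the checkpoint, must be freshly initialized.
with tf.variable_scope('added_layer'):
    fc_1 = slim.fully_connected(fc, 512, activation_fn=tf.nn.elu)

# Saver restricted to the old scope, so restore() never looks for the new keys.
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='old_scope'),
                       max_to_keep=100)
new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='added_layer')

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(model_path)
    saver.restore(sess, ckpt.model_checkpoint_path)    # load the old weights
    sess.run(tf.variables_initializer(new_vars))       # the new layer starts from scratch

    # A saver over all variables can then write future checkpoints that
    # contain both the old and the new layers.
    full_saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)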