2017-12-27 96 views
3

只有存在時纔可以恢復變量嗎?這樣做的最習慣的方式是什麼?TensorFlow - 恢復(如果存在)

例如,考慮下面的小例子:

import tensorflow as tf 
import glob 
import sys 
import os 

with tf.variable_scope('volatile'): 
    x = tf.get_variable('x', initializer=0) 

with tf.variable_scope('persistent'): 
    y = tf.get_variable('y', initializer=0) 
    add1 = tf.assign_add(y, 1) 

saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'persistent')) 

sess = tf.InteractiveSession() 
tf.global_variables_initializer().run() 
tf.get_default_graph().finalize() 

print('save file', sys.argv[1]) 
if glob.glob(sys.argv[1] + '*'):  
    saver.restore(sess, sys.argv[1]) 

print(sess.run(y)) 
sess.run(add1) 
print(sess.run(y)) 
saver.save(sess, sys.argv[1]) 

當使用相同的參數運行兩次,程序首先打印0\n1然後1\n2預期。現在假設您通過在persistent範圍內的add1之後添加z = tf.get_variable('z', initializer=0)來更新您的代碼以具有新功能。再次運行這個時候,老保存文件存在將具有以下突破:

NotFoundError(見上文回溯):在檢查點沒有找到關鍵持久/ Z [節點:保存/ RestoreV2_1 = RestoreV2 [dtypes = [DT_INT32],_device =「/ job:localhost/replica:0/task:0/device:CPU:0」](_ arg_save/Const_0_0,save/RestoreV2_1/tensor_names,save/RestoreV2_1/shape_andslices)]] [ Node:save/Assign_1/_18 = _Recvclient_terminated = false,recv_device =「/ job:localhost/replica:0/task:0/device:GPU:0」,send_device =「/ job:localhost/replica:0/task:0/device:CPU:0「,send_device_incarnation = 1,tensor_name =」edge_12_save/Assign_1「,tensor_type = DT_FLOAT,_device =」/ job:localhost/replica:0/task:0/device:GPU:0「]]

回答

1

您可以使用下面的功能恢復(從here拍攝):

def optimistic_restore(session, save_file, graph=tf.get_default_graph()): 
    reader = tf.train.NewCheckpointReader(save_file) 
    saved_shapes = reader.get_variable_to_shape_map() 
    var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables() 
      if var.name.split(':')[0] in saved_shapes])  
    restore_vars = []  
    for var_name, saved_var_name in var_names:    
     curr_var = graph.get_tensor_by_name(var_name) 
     var_shape = curr_var.get_shape().as_list() 
     if var_shape == saved_shapes[saved_var_name]: 
      restore_vars.append(curr_var) 
    opt_saver = tf.train.Saver(restore_vars) 
    opt_saver.restore(session, save_file) 

我通常運行sess.run(tf.global_variables_initializer()),以確保所有的變量初始化,然後我跑optimistic_restore(sess,...)恢復它可以是變量恢復。

0

您可以創建兩個保存程序,一個用於恢復檢查點中的變量,另一個用於保存檢查點和新添加的變量中的變量。爲了恢復檢查點中的變量,綁定到保存器的所有變量應該在檢查點(這就是錯誤發生的原因),並且如果以這種方式執行操作,則應該先恢復所有變量,然後再恢復部分定義的變量存在於檢查點中。

相關問題