2017-10-28 108 views
0

我正在使用tensorflow存儲庫中image_retraining文件夾中提供的再培訓腳本。通過加載intermediate_output_graphs(.pb)繼續培訓(image_retraining/retrain.py)

之一解析器參數/標誌讓你存儲中間圖形每隔X步驟

parser.add_argument(
     '--intermediate_output_graphs_dir', 
     type=str, 
     default='tf_files2/tmp/intermediate_graph/', 
     help='Where to save the intermediate graphs.' 

然而,這似乎圖表存儲爲冷凍圖形與.pb擴展。 有關如何正確加載.pb文件以繼續培訓的信息非常少。 我發現的大多數信息都使用.meta圖和.ckpts。 .pb將被棄用?

如果是這樣,我應該重新從模型開始,並使用tf.Saver獲得 .meta和ckpt圖作爲中間檢查點?

昨天,我正在訓練一個模型,由於某種原因訓練凍結了,所以我想加載中間圖形,並繼續訓練。

我使用以來模型進行再培訓:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

如果任何人都可以點我或告訴我如何正確地加載了一個.pb中間曲線(循序漸進),從我離開的地方繼續關 - 我真的很感激它。

謝謝。

編輯:

@Mingxing

所以我假設我應該讓retrain.py首先創建默認的圖形基於默認以來模型(下此功能),然後就用它覆蓋加載圖?

def create_model_graph(model_info): 
    """"Creates a graph from saved GraphDef file and returns a Graph object. 

    Args: 
    model_info: Dictionary containing information about the model architecture. 

    Returns: 
    Graph holding the trained Inception network, and various tensors we'll be 
    manipulating. 
    """ 
    with tf.Graph().as_default() as graph: 
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name']) 
    with gfile.FastGFile(model_path, 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
      graph_def, 
      name='', 
      return_elements=[ 
       model_info['bottleneck_tensor_name'], 
       model_info['resized_input_tensor_name'], 
      ])) 
    return graph, bottleneck_tensor, resized_input_tensor 

EDIT_2:

我得到一個錯誤是:

ValueError: Tensor("second_to_final_fC_layer_ops/weights/final_weights_1:0", shape=(2048, 102 
4), dtype=float32_ref) must be from the same graph as Tensor("BottleneckInputPlaceholder:0", 
shape=(?, 2048), dtype=float32). 

我第一FC層後添加一個額外的FC層。 因此2048 - > 1024 - >變異前的分類數目。

當訓練模型時,我沒有問題,但現在加載圖表我似乎遇到了上述錯誤。

這是所添加的層的外觀:

layer_name = 'second_to_final_fC_layer_ops' 
    with tf.name_scope(layer_name): 
    with tf.name_scope('weights'): 
     initial_value = tf.truncated_normal(
      [bottleneck_tensor_size, 1024], stddev=0.001) 

     layer_weights = tf.Variable(initial_value, name='weights') 

     variable_summaries(layer_weights) 
    with tf.name_scope('biases'): 
     layer_biases = tf.Variable(tf.zeros([1024]), name='biases') 
     variable_summaries(layer_biases) 
    with tf.name_scope('Wx_plus_b'): 
     logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases 
     tf.summary.histogram('pre_activations', logits) 
    with tf.name_scope('Relu_activation'): 
     relu_activated =tf.nn.relu(logits, name= 'Relu') 
     tf.summary.histogram('final_relu_activation', relu_activated) 

然後最後的層(其是原始最後的層,但現在的輸入是輸出來自最後一層,而不是瓶頸張量):

layer_name = 'final_training_ops' 
    with tf.name_scope(layer_name): 
    with tf.name_scope('weights'): 
     initial_value = tf.truncated_normal(
      [1024, class_count], stddev=0.001) 

     layer_weights = tf.Variable(initial_value, name='final_weights') 

     variable_summaries(layer_weights) 
    with tf.name_scope('biases'): 
     layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') 
     variable_summaries(layer_biases) 
    with tf.name_scope('Wx_plus_b'): 
     logits = tf.matmul(relu_activated, layer_weights) + layer_biases 
     tf.summary.histogram('pre_activations', logits) 

    final_tensor = tf.nn.softmax(logits, name=final_tensor_name) 
    tf.summary.histogram('activations', final_tensor) 

編輯:還是不知道如何加載weights--加載圖形結構似乎很容易,但我不知道如何加載已使用的培訓,一旦盜夢空間的權重和投入再次使用轉移學習。

使用來自image_retraining/retrain.py的權重和變量的清晰示例將非常有用。謝謝。

回答

1

您可以使用tf.import_graph_def導入您的冷凍.pb文件:

# Read the .pb file into graph_def. 
with tf.gfile.GFile(FLAGS.graph, "rb") as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 

# Restore the graph. 
with tf.Graph().as_default() as graph: 
    tf.import_graph_def(graph_def, name="") 

# After this, graph is the what you need. 

雖然沒有什麼錯誤直接使用冷凍.pb文件,我還是要指出的是,推薦的方法是遵循標準保存/恢復(official doc)。

+0

謝謝。我用解決方案的另一個問題更新了OP。當你有時間的時候,請看看。謝謝=) – Moondra

+0

我創建了另一個關於我得到的錯誤的更新。不急。只需登錄更新。謝謝。 – Moondra