2017-06-13 137 views
0

我正在使用張量流的imageNet訓練模型來分類圖像的多個類別。ValueError:GraphDef不能大於2GB

我編輯的腳本classify.py作爲

import tensorflow as tf 
import sys 
import glob 
import os 
import pandas as pd 

# Disable tensorflow compilation warnings 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
import tensorflow as tf 

test_path = '/Users/kaustubhmundra/Desktop/Multi-Class Classifier/test' 

classes = ['room','reception','washroom','facade'] 

result = pd.DataFrame(columns = ['facade','washroom','room','reception']) 

def predict(image_path): 
    #image_path = sys.argv[1] 

    # Read the image_data 
    image_data = tf.gfile.FastGFile(image_path, 'rb').read() 

    # Loads label file, strips off carriage return 
    label_lines = [line.rstrip() for line 
         in tf.gfile.GFile("tf_files/retrained_labels.txt")] 

    # Unpersists graph from file 
    with tf.gfile.FastGFile("tf_files/retrained_graph.pb", 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     _ = tf.import_graph_def(graph_def, name='') 

    with tf.Session() as sess: 
     # Feed the image_data as input to the graph and get first prediction 
     softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') 

     predictions = sess.run(softmax_tensor, \ 
       {'DecodeJpeg/contents:0': image_data}) 

     # print(predictions) 

     pred = pd.DataFrame(predictions,columns = ['facade','washroom','room','reception']) 

     # print(pred) 

     global result 

     result = result.append(pred) 

     # print(result) 

     # Sort to show labels of first prediction in order of confidence 
     top_k = predictions[0].argsort()[-len(predictions[0]):][::-1] 

     for node_id in top_k: 
      human_string = label_lines[node_id] 
      score = predictions[0][node_id] 
      print('%s (score = %.5f)' % (human_string, score)) 



path = os.path.join(test_path, '*') 
files = sorted(glob.glob(path)) 

i=1 

for fl in files: 
    print(i) 
    i = i + 1 
    predict(fl) 

result.to_csv('predictions.csv') 

雖然我用它來預測上的圖像,它完美的作品,直到24倍的圖像,但隨後顯示了一個錯誤:

File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2154, in _as_graph_def raise ValueError("GraphDef cannot be larger than 2GB.") ValueError: GraphDef cannot be larger than 2GB.

我如何解決這個問題?

回答

0

您每次調用predict()時都會導入圖表,因此您正在累積非常大的默認graphdef。您應該更改代碼,以便僅在預測函數之外加載一次圖形(「文件中的#Unpersists圖形」部分)。這也可以大大加快你的代碼。

+0

非常感謝! 這個工程令人驚訝。這很簡單,我不知道爲什麼它沒有打我。 :) –

相關問題