2017-08-01

My question is about how to get batched inputs from multiple (sharded) TFRecord files. I have read the example at https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410. The basic pipeline there, taking the training set as an example, is: (1) first generate a list of sharded record files (e.g., train-000-of-005, train-001-of-005, ...); (2) feed that list into tf.train.string_input_producer; (3) at the same time create a tf.RandomShuffleQueue to do other work; (4) use tf.train.batch_join to generate batched inputs. Is there a simpler way to handle batched inputs from TFRecords?

I find this quite complicated, and I am not sure about the logic of the pipeline. In my case, I have a list of .npy files, and I want to generate sharded TFRecords (multiple separate TFRecord files, not just one big file). Each of these .npy files contains a different number of positive and negative samples (two classes). A basic approach would be to generate one big TFRecord file, but that file would be too large (~20 GB), so I resorted to sharded TFRecords. Is there a simpler way to do this? Thanks.

Answer

The whole process is simplified with the Dataset API. There are two parts here: (1) convert the numpy arrays to TFRecords, and (2, 3, 4) read the TFRecords to generate batches.

1. Creating TFRecords from numpy arrays:

def npy_to_tfrecords(X, y, output_file):
    # Write records to a tfrecords file
    writer = tf.python_io.TFRecordWriter(output_file)

    # Loop through all the samples you want to write;
    # X is an np.array of shape [n_samples, n_features],
    # y is an np.array of shape [n_samples] holding 0/1 labels
    for i in range(X.shape[0]):
        # Feature contains a map of string to feature proto objects
        feature = {}
        feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X[i].flatten()))
        feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[y[i]]))

        # Construct the Example proto object
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize the example to a string
        serialized = example.SerializeToString()

        # Write the serialized object to the disk
        writer.write(serialized)
    writer.close()
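The question also asks about writing several sharded files rather than one ~20 GB record. One simple option is to deal the list of .npy files round-robin across shard names in the train-000-of-005 style and call a writer such as the one above once per shard. A minimal sketch of the pure-Python bookkeeping (the helper names here are hypothetical, not part of any TensorFlow API):

```python
def shard_filenames(prefix, num_shards):
    # Produce names in the train-000-of-005 style used by the Inception example.
    return ["%s-%03d-of-%03d" % (prefix, i, num_shards) for i in range(num_shards)]

def partition_round_robin(items, num_shards):
    # Deal the .npy files out like cards so every shard stays roughly the same size.
    shards = [[] for _ in range(num_shards)]
    for i, item in enumerate(items):
        shards[i % num_shards].append(item)
    return shards

npy_files = ["a.npy", "b.npy", "c.npy", "d.npy", "e.npy"]
names = shard_filenames("train", 5)           # ['train-000-of-005', ...]
groups = partition_round_robin(npy_files, 5)
# Each (name, group) pair would then be written with one TFRecordWriter per shard.
```

Since each .npy file holds a different mix of the two classes, round-robin dealing only balances shard sizes, not class counts; shuffling the file list first is a cheap way to spread the classes out.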

2. Reading the TFRecords with the Dataset API (TensorFlow >= 1.2):

# Creates a dataset that reads all of the examples from filenames.
filenames = ["file1.tfrecord", "file2.tfrecord", ..., "fileN.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)

# Example proto decode
def _parse_function(example_proto):
    # shape_of_npy_array is the shape of each stored X array
    keys_to_features = {'X': tf.FixedLenFeature((shape_of_npy_array), tf.float32),
                        'y': tf.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return parsed_features['X'], parsed_features['y']

# Parse the record into tensors.
dataset = dataset.map(_parse_function)

# Shuffle the dataset
dataset = dataset.shuffle(buffer_size=10000)

# Repeat the input indefinitely
dataset = dataset.repeat()

# Generate batches
dataset = dataset.batch(batch_size)

# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()

# Get batch X and y
X, y = iterator.get_next()

# Usage: fetch batches inside a session, e.g.
# with tf.Session() as sess:
#     batch_X, batch_y = sess.run([X, y])

Ah, thank you so much for the detailed answer! You saved my life! – mining


Hi, does this API support anything like the 'num_threads' or 'capacity' arguments of 'tf.train.shuffle_batch'? In my case, if the network is small, execution on the GPU is faster than data loading, which leaves the GPU idle. So I would like the input queue to always stay full. Thanks. – mining


Check: https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map –
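To expand on that pointer: the exact knobs depend on the TensorFlow release. The 1.2 contrib `Dataset.map` accepted `num_threads` and `output_buffer_size` arguments, while later `tf.data` releases use `num_parallel_calls` on `map` plus a separate `prefetch` transformation to keep the input pipeline ahead of the GPU. A minimal sketch against the newer `tf.data` API (the parse function here is a stand-in for `_parse_function`, not the real decoder):

```python
import tensorflow as tf

# Stand-in for the real example-proto parser; just doubles each element.
def _fake_parse(x):
    return x * 2

dataset = tf.data.Dataset.range(5)
# Run the parse step in parallel worker threads so decoding keeps up with training.
dataset = dataset.map(_fake_parse, num_parallel_calls=4)
# Keep a small buffer of ready elements so the consumer never waits on input.
dataset = dataset.prefetch(2)
```

The `prefetch` buffer plays the role of the always-full queue asked about in the comment: elements are produced in the background while the previous batch is being consumed.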
