2017-10-06 71 views

Continuing from this question and the discussion here - I am trying to use the Dataset API to take a dataset of variable-length tensors and slice them into segments (fragments) of equal length. Something like: generating a dataset of strided slices from a tfrecords dataset.

Dataset = tf.contrib.data.Dataset 
segment_len = 6 
batch_size = 16 

with tf.Graph().as_default() as g: 
    # get the tfrecords dataset 
    dataset = tf.contrib.data.TFRecordDataset(filenames).map(
     partial(record_type.parse_single_example, graph=g)).batch(batch_size) 
    # zip it with the number of segments we need to slice each tensor 
    dataset2 = Dataset.zip((dataset, Dataset.from_tensor_slices(
     tf.constant(num_segments, dtype=tf.int64)))) 
    it2 = dataset2.make_initializable_iterator() 
    def _dataset_generator(): 
     with g.as_default(): 
      while True: 
       try: 
        (im, length), count = sess.run(it2.get_next()) 
        dataset3 = Dataset.zip((
         # repeat each tensor then use map to take a strided slice 
         Dataset.from_tensors((im, length)).repeat(count), 
         Dataset.range(count))).map(lambda x, c: (
          x[0][:, c: c + segment_len], 
          x[0][:, c + 1: (c + 1) + segment_len], 
        )) 
        it = dataset3.make_initializable_iterator() 
        it_init = it.initializer 
        try: 
         yield it_init 
         while True: 
          yield sess.run(it.get_next()) 
        except tf.errors.OutOfRangeError: 
         continue 
       except tf.errors.OutOfRangeError: 
        return 
    # Dataset.from_generator needs tensorflow > 1.3 ! 
    das_dataset = Dataset.from_generator(
     _dataset_generator, 
     (tf.float32, tf.float32), 
     # (tf.TensorShape([]), tf.TensorShape([])) 
    ) 
    das_dataset_it = das_dataset.make_one_shot_iterator() 


with tf.Session(graph=g) as sess: 
    while True: 
     print(sess.run(it2.initializer)) 
     print(sess.run(das_dataset_it.get_next())) 

Of course I don't want to pass the session into the generator, but that should be workaroundable via the trick given in the link (create a dummy dataset and map the initializer of the other iterator). The code above fails with:

tensorflow.python.framework.errors_impl.InvalidArgumentError: TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.ops.Operation'>. 
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]] 
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]] 

This, I guess, is because I am yielding the iterator's initializer, but my question is basically whether I can achieve what I'm trying to do using the Dataset API at all.
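For clarity, the slicing the question is after can be sketched in plain Python, with no TensorFlow involved (the helper name `strided_slices` is made up here for illustration): for each variable-length sequence, emit overlapping windows of length `segment_len`, each paired with the same window shifted by one step.

```python
# Plain-Python sketch of the intended output: overlapping length-segment_len
# windows, paired with the window shifted by one (e.g. input/target pairs).
segment_len = 6

def strided_slices(sequence, segment_len):
    """Yield (window, next_window) pairs, both of length segment_len."""
    count = len(sequence) - segment_len  # number of full window pairs
    for c in range(count):
        yield (sequence[c: c + segment_len],
               sequence[c + 1: c + 1 + segment_len])

pairs = list(strided_slices(list(range(10)), segment_len))
# pairs[0] == ([0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6])
```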

Answer


The easiest way to build a Dataset from a nested Dataset is to use the Dataset.flat_map() transformation. This transformation applies a function to each element of the input dataset (dataset2 in your example); that function returns a nested Dataset (most likely dataset3 in your example); and the transformation then flattens all the nested datasets into a single Dataset.

dataset2 = ... # As above. 

def get_slices(im_and_length, count): 
    im, length = im_and_length 
    # Repeat each tensor then use map to take a strided slice. 
    return Dataset.zip((
     Dataset.from_tensors((im, length)).repeat(count), 
     Dataset.range(count))).map(lambda x, c: (
      x[0][:, c: c + segment_len], 
      x[0][:, c + 1: (c + 1) + segment_len], 
     )) 

das_dataset = dataset2.flat_map(get_slices) 
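As an aside, the flattening behaviour of flat_map can be sketched with plain Python iterables, no TensorFlow needed (the `flat_map` helper below is an illustrative stand-in, not the TensorFlow API): each input element maps to a whole sub-sequence, and the results are concatenated into one flat stream, which mirrors what dataset2.flat_map(get_slices) does with nested Datasets.

```python
from itertools import chain

def flat_map(fn, iterable):
    """Apply fn (element -> iterable) to each element and flatten the results."""
    return chain.from_iterable(fn(x) for x in iterable)

# Each "element" expands into several slices, like get_slices above.
flat = list(flat_map(lambda n: [(n, i) for i in range(n)], [1, 2, 3]))
# flat == [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
```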

Excellent, thanks - it hadn't occurred to me that flat_map is the tool for the job.


FYI this may not play well with MonitoredTrainingSession - the iterator sometimes appears to be advanced unexpectedly (because it is bound to ops like model summaries?) - or I may be wrong and it's my own mistake. Will have to investigate further, but in the meantime, since the integration of MonitoredTrainingSession and Datasets is being discussed on github, I just wanted to note it so you keep it in mind too - that is, we should at least warn people to be careful about advancing iterators hidden inside ops when using MonitoredTrainingSession.