2017-06-19 131 views
0

我正在關注TensorFlow的Generative Adversarial Network的教程。本教程使用MNIST數據集來訓練模型。我想減少輸入的大小,以便我的程序運行速度更快,但不知道如何獲取我正在使用的MNIST數據集的子集。下面是我用於提取所述數據集的代碼:如何子集MNIST數據集?

from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets("MNIST_data/") 

回答

0

有一種方法

mnist.next_batch(batchsize) 

提取從列車組的長度BATCHSIZE的隨機樣本。

如果你不想要的東西是隨機的,您可以通過

x = mnist.train.images[start_batch:end_batch] 
y = mnist.train.labels[start_batch:end_batch] 

或類似與mnist.test訪問它們的測試集。

+0

嗨,非常感謝您的回覆。我能夠使用您提供的方法對train.images和train.labels進行子集劃分。但是,在將這些數據集分組後,我得到一個NDArray對象,並且我無法調用爲ndarray的mnist數據集設計的任何方法。有沒有什麼辦法可以將ndarray放回mnist數據集? – nnguyen24