2017-06-15 1038 views
0

我目前正在通過實施深度轉網絡來研究kaggle上的cats vs dogs分類任務。下面的代碼行用於數據預處理:如何在python中爲自定義數據實現next_batch()函數

def label_img(img): 
    word_label = img.split('.')[-3] 
    if word_label == 'cat': return [1,0] 
    elif word_label == 'dog': return [0,1] 

def create_train_data(): 
    training_data = [] 
    for img in tqdm(os.listdir(TRAIN_DIR)): 
     label = label_img(img) 
     path = os.path.join(TRAIN_DIR,img) 
     img = cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),IMG_SIZE,IMG_SIZE)) 
     training_data.append([np.array(img),np.array(label)]) 

    shuffle(training_data) 
    return training_data 

train_data = create_train_data() 

X_train = np.array([i[0] for i in train_data]).reshape(-1, IMG_SIZE,IMG_SIZE,1) 
Y_train =np.asarray([i[1] for i in train_data]) 

我想要實現複製在tensorflow深MNIST教程

batch = mnist.train.next_batch(100) 

回答

0

code提供了以下功能的功能就是一個很好的例子拿出生成批處理的功能。

簡單說明,你只需要拿出兩個數組的x_train和y_train喜歡:

batch_inputs = np.ndarray(shape=(batch_size), dtype=np.int32) 
    batch_labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 

並設置列車數據,如:

batch_inpouts[i] = ... 
    batch_labels[i, 0] = ... 

最後通過數據設置爲會話:

_, loss_val = session.run([optimizer, loss], feed_dict={train_inputs: batch_inputs, train_labels:batch_labels}) 
+0

請試試看。謝謝你的時間。 –

2

除了生成一個批處理,您可能還想隨機重新安排數據每批次。

EPOCH = 100 
BATCH_SIZE = 128 
TRAIN_DATASIZE,_,_,_ = X_train.shape 
PERIOD = TRAIN_DATASIZE/BATCH_SIZE #Number of iterations for each epoch 

for e in range(EPOCH): 
    idxs = numpy.random.permutation(TRAIN_DATASIZE) #shuffled ordering 
    X_random = X_train[idxs] 
    Y_random = Y_train[idxs] 
    for i in range(PERIOD): 
     batch_X = X_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     batch_Y = Y_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     sess.run(train,feed_dict = {X: batch_X, Y:batch_Y}) 
+0

非常感謝。最後,我可以正確地訓練我的網絡。 –

+0

你能否啓發我tensorflow的next_batch()返回什麼?它是指定批次大小的訓練集中的隨機數據集合嗎?如果是這樣,它確保不重複? @Joshua Lim –

+0

next_batch()是一個專門針對tensorflow提供的MNIST教程的函數。它的工作原理是在開始時隨機化訓練圖像和標籤對,並在每次調用函數時選擇每個後續100張圖像。一旦達到最後,圖像標籤對就會再次被隨機化,並且重複該過程。整個數據集只有在使用所有可用對時纔會重新洗牌並重復。 –

相關問題