2017-08-08

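I know the keras image_ocr model. It trains on images that it generates itself; however, I'm having some difficulty because I want to train the model on my own dataset. How do I give my own dataset to keras image_ocr?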

The repo link is: https://github.com/fchollet/keras/blob/master/examples/image_ocr.py

I have created two arrays, x and y. My image paths and their corresponding ground truth (gt) are in a CSV file.

x holds the images, with dimensions [nb_samples, W, H, C].
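One thing worth checking before stacking the images: cv2.imread(path, 0) returns a 2-D array of shape (H, W), so np.array(x) gives (nb_samples, H, W), not [nb_samples, W, H, C]. As far as I can tell from image_ocr.py, its generator fills X_data as (size, img_w, img_h, 1) under the channels_last setting, transposing each painted image so that the width becomes the time axis, with pixel values in [0, 1]. The sketch below does the same for real images. The helper name to_image_ocr_layout and the 128x64 default size are assumptions, and every image must already be resized to that size (the resize in the preprocessing loop is currently commented out).

```python
import numpy as np


def to_image_ocr_layout(images, img_w=128, img_h=64):
    """Stack grayscale (H, W) images into a float32 array of shape (N, W, H, 1).

    Hypothetical helper. Assumes every image was already resized to
    (img_h, img_w), e.g. with cv2.resize(img, (img_w, img_h)); note that
    cv2.resize takes (width, height).
    """
    batch = np.zeros((len(images), img_w, img_h, 1), dtype=np.float32)
    for i, img in enumerate(images):
        if img.shape != (img_h, img_w):
            raise ValueError("image %d has shape %r, expected %r"
                             % (i, img.shape, (img_h, img_w)))
        # Transpose (H, W) -> (W, H) so width comes first, then scale 0..255 to 0..1.
        batch[i, :, :, 0] = img.T.astype(np.float32) / 255.0
    return batch
```

In the preprocessing, x = to_image_ocr_layout(x) would then take the place of x = np.array(x).astype(np.float32).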

y holds the string labels, i.e. the GT.
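The string labels in y also need converting, because CTC works on integer class indices rather than strings. image_ocr.py does this with an alphabet string and a text_to_labels helper; the sketch below is a hypothetical equivalent (encode_labels, ALPHABET and max_len=16 are assumptions, so use the characters that actually occur in your GT). It returns a padded label matrix plus a (nb_samples, 1) array of label lengths.

```python
import numpy as np

# Hypothetical alphabet: replace with the characters that occur in your GT.
ALPHABET = u'abcdefghijklmnopqrstuvwxyz '


def encode_labels(texts, alphabet=ALPHABET, max_len=16):
    """Map each string to alphabet indices, padded to max_len (hypothetical helper)."""
    labels = np.zeros((len(texts), max_len), dtype=np.float32)
    label_length = np.zeros((len(texts), 1), dtype=np.float32)
    for i, text in enumerate(texts):
        if len(text) > max_len:
            raise ValueError("label %r is longer than max_len=%d" % (text, max_len))
        for j, ch in enumerate(text):
            idx = alphabet.find(ch)
            if idx < 0:
                # Fail loudly instead of silently storing -1 for unknown characters.
                raise ValueError("character %r is not in the alphabet" % ch)
            labels[i, j] = idx
        label_length[i, 0] = len(text)
    return labels, label_length
```

If you change the alphabet, the model's softmax layer needs len(alphabet) + 1 units, the extra class being the CTC blank; image_ocr.py sizes its output the same way. The padding value after each label should not matter, since label_length tells the CTC loss where each label ends.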

Here is the code I use for preprocessing:

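import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

read_file = pd.read_csv('labels.csv')  # hypothetical file name; columns 'path' and 'gt'
x = []
y = []

for i in range(0, len(read_file)):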
    path = read_file['path'][i] 
    label = read_file['gt'][i] 
    path = path.strip('\n') 
    img = cv2.imread(path,0) 
    #Re-sizing the images 
    #height = 64, width = 128 
    #res_img = cv2.resize(img, (128,64)) 
    #cv2.imwrite(i,res_img) 
    h,w = img.shape 
    x.append(img) 
    y.append(label) 
    size = img.size 
    """ 
    print "Height: ", h #Height 
    print "Width: ", w #Width 
    print "Channel: ", C#Channel 
    print "Size: ", size 
    print "\n" 
    """ 
print "H: ", h 
print "W: ", w 
print "S: ", size 

x = np.array(x).astype(np.float32) 
y = np.array(y) 

x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.3,random_state=42) 

x_train = np.array(x_train).astype(np.float32) 
y_train = np.array(y_train) 
x_train = np.array(x_train) 
x_test = np.array(x_test) 
y_test = np.array(y_test) 

print "Printing the shapes. \n" 
print "X_train shape: ", x_train.shape 
print "Y_train shape: ", y_train.shape 
print "X_test shape: ", x_test.shape 
print "Y_test shape: ", y_test.shape 
print "\n" 

This is followed by the keras image_ocr code. The full code is here: https://gist.github.com/kjanjua26/b46388bbde9ded5cf1f077a9f0dedc4f

The error I get when I run it is:

`Traceback (most recent call last):
  File "preprocess.py", line 323, in <module>
    train(run_name, 0, 20, w)
  File "preprocess.py", line 314, in train
    model.fit(next_train(x_train), y_train, batch_size=7, epochs=20, verbose=1, validation_split=0.1, shuffle=True, initial_epoch=0)
  File "/home/kamranjanjua/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1358, in fit
    batch_size=batch_size)
  File "/home/kamranjanjua/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1234, in _standardize_user_data
    exception_prefix='input')
  File "/home/kamranjanjua/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 100, in _standardize_input_data
    'Found: ' + str(data)[:200] + '...')
TypeError: Error when checking model input: data should be a Numpy array, or list/dict of Numpy arrays. Found: <generator object next_train at 0x7f8752671640>...`

Any help would be appreciated.

Answer


If you look closely at the code, you will see that the model expects a dictionary as its input:

inputs = {'the_input': X_data,'the_labels': labels, 'input_length': input_length,'label_length': label_length,'source_str': source_str} 

outputs = {'ctc': np.zeros([size])} # dummy data for dummy loss function 

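For the inputs:
1) X_data (the_input) holds the training examples;
2) labels (the_labels) holds the labels of the corresponding training examples;
3) label_length is the length of each label;
4) input_length is the length of each input sequence as the CTC loss sees it, i.e. the number of RNN time steps;
5) source_str is not mandatory; it is only used for decoding.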
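Putting those inputs together, here is a minimal sketch of building one batch from your own arrays, assuming x is already in the (N, W, H, 1) layout and the labels are already integer-encoded. The function name make_batch is hypothetical. The input_length rule, img_w // downsample_factor - 2 with downsample_factor = 4, mirrors what image_ocr.py appears to use for its two 2x2 pooling layers, minus the first two RNN outputs that its CTC lambda discards; if you change the network, recompute it.

```python
import numpy as np


def make_batch(x, labels, label_length, texts, start, size, downsample_factor=4):
    """Slice one batch and package it as the image_ocr.py model expects (hypothetical helper)."""
    end = start + size
    x_batch = x[start:end]
    n = x_batch.shape[0]
    img_w = x_batch.shape[1]  # channels_last layout: (N, W, H, 1)
    # Same number of RNN time steps for every sample: width after pooling, minus 2.
    input_length = np.full((n, 1), img_w // downsample_factor - 2, dtype=np.float32)
    inputs = {
        'the_input': x_batch,
        'the_labels': labels[start:end],
        'input_length': input_length,
        'label_length': label_length[start:end],
        'source_str': list(texts[start:end]),  # only used for decoding
    }
    # Dummy targets: the real CTC loss is computed inside the model.
    outputs = {'ctc': np.zeros([n])}
    return inputs, outputs
```

For a 128-pixel-wide image this gives 30 time steps, so no label can be longer than that (in practice somewhat shorter, since CTC needs a blank between repeated characters).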

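The output is just dummy data: the 'ctc' entry is an array of zeros, because the real CTC loss is computed inside the model and the compiled loss function simply returns it.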

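In your code, you only produce X_train, and you pass y_train where the dummy data for the CTC loss should go; the other inputs are missing. You need to prepare your dataset according to the inputs and outputs the model expects.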
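The traceback itself comes from passing a generator, next_train(x_train), to model.fit, which in the Keras version shown only accepts NumPy arrays or lists/dicts of them. The example script instead feeds model.fit_generator with a generator that yields (inputs, outputs) tuples of dictionaries. A minimal sketch of such a generator for your own data follows; batch_generator and its parameters are hypothetical, and it assumes x is in the (N, W, H, 1) layout with integer-encoded labels and an (N, 1) label_length array.

```python
import numpy as np


def batch_generator(x, labels, label_length, batch_size=32,
                    downsample_factor=4, shuffle=True, seed=None):
    """Yield ({inputs}, {'ctc': zeros}) batches forever (hypothetical helper)."""
    n = x.shape[0]
    rng = np.random.RandomState(seed)
    # Assumed image_ocr.py rule: pooled width minus the 2 discarded RNN outputs.
    time_steps = x.shape[1] // downsample_factor - 2
    while True:
        # Reshuffle once per pass over the data.
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            inputs = {
                'the_input': x[idx],
                'the_labels': labels[idx],
                'input_length': np.full((len(idx), 1), time_steps, dtype=np.float32),
                'label_length': label_length[idx],
            }
            outputs = {'ctc': np.zeros([len(idx)])}  # dummy targets for the CTC loss
            yield inputs, outputs
```

It would be used roughly as model.fit_generator(batch_generator(x_train, labels_train, lengths_train, batch_size=7), steps_per_epoch=int(np.ceil(len(x_train) / 7.0)), epochs=20), where labels_train and lengths_train are placeholders for your encoded training labels. The generator loops forever because fit_generator decides when an epoch ends from steps_per_epoch. Recent tf.keras versions also accept a generator directly in fit.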
