2017-08-29 89 views
1

我正在構建一個神經網絡模型,其數據矩陣的列數(行爲)相同但行數不同,因此訓練標籤的大小也不相同。我使用logit作爲損失函數的交叉熵。Tensorflow - softmax與不同大小的數據

例如,我有數據是這樣的:

data1 = np.array([[0.1,0.2,0.3],[0.2,0.3,0.4],[0.3,0.4,0.5]]) 
data2 = np.array([[0.2,0.3,0.4],[0.3,0.4,0.5]]) 
label_1 = np.array([0,0,1]) 
label_2 = np.array([1,0]) 

所以我們有尺寸(3×)和偏差b的權重向量,並且我們的損失函數是:

loss = -1*(log(label_1*softmax(data1*weight+b)) + log(label_2*softmax(data2*weight+b))) 

在Tensorflow,我知道我可以定義一個佔位符無維度,例如:

tf_data = tf.placeholder(tf.float32, shape=(batch_size, None, feature_size)) 
tf_labels = tf.placeholder(tf.float32, shape=(batch_size, None)) 

我的問題是,我怎麼能養活DAT a到feed_dict?由於data = np.array([data1,data2])返回numpy的陣列中存儲兩個numpy的陣列和導致:

ValueError: setting an array element with a sequence 

此外,什麼tensorflow函數I可以用它來計算data*weight? tf.matmul(data,weight)結果在形狀必須秩2但爲秩3「MATMUL」

回答

0

看起來你的輸入數據是一個序列,你的輸出數據是一個序列,它的輸入序列中有許多元素。

我的問題是,如何將數據提供給feed_dict?由於數據= np.array([DATA1,DATA2])

出現錯誤,因爲data_2第二尺寸比的data_1第二尺寸。您可以填充data_2label_2以使其具有與data_1label_1分別相同的形狀。

此外,您應該定義一個掩碼佔位符,以確保label_2的填充部分未納入您的損失函數的計算中。

所以我們有尺寸(3X1)的權重向量

這是行不通的。你的體重矢量應該有形狀(input_dim, output_dim)。在這種情況下,您似乎有二維輸入的二維變化。對於這個維度,你應該儘可能的使用序列長度(因爲你會將較小的輸入值填充到這個值)。輸出尺寸也將是最大序列長度。說這是4,那麼W的形狀應該是:(3, 4, 4)b的形狀應該是(4,)

此外,什麼張量流函數我可以用來計算數據*權重? tm.matmul(數據,重量)導致Shape必須爲2級,但爲'MatMul'排名爲3

tf.matmul是正確的。當數據已經形成(batch_size, 3, 4)和重量形狀(3, 4, 4)。結果將形成(batch_size, 4)

這是你的代碼的更新版本(再次假設最大序列長度爲4):

data1 = np.array([[0.1,0.2,0.3],[0.2,0.3,0.4],[0.3,0.4,0.5]]) 
data2 = np.array([[0.2,0.3,0.4],[0.3,0.4,0.5]]) 
data1_padded = np.concatenate([data1, np.zeros(shape=(1,3), dtype=np.float32)], 0) 
data2_padded = np.concatenate([data2, np.zeros(shape=(2,3), dtype=np.float32)], 0) 

label_1 = np.array([0,0,1.]) 
label_1_padded = np.concatenate([label_1, np.zeros(shape=(1), dtype=np.float32)], 0) 
label_2 = np.array([1.,0]) 
label_2_padded = np.concatenate([label_2, np.zeros(shape=(2), dtype=np.float32)], 0) 

mask_1 = np.array([1., 1., 0., 0.]) 
mask_2 = np.array([1., 1., 1., 0.]) 

話雖這麼說,你提出了網絡架構(前饋神經網絡)是不是很適合序列。我建議你看看遞歸神經網絡和/或卷積神經網絡。實現將類似,你也需要填充。

相關問題