2017-08-25 73 views
1

一定的乘法我有尺寸[a,b,c,d]的一個張量A和尺寸[b,b,d,e]的另一BC[a]整數從0到b列表。我需要製作尺寸[a,b,c,e]的張量D通過如何實現在tensorflow

D[i,j,k,l] = sum for m=0..d of A[i,C[i],k,m] * B[C[i],j,m,l]

b給予足夠小(3或5,通常是?)我不介意在b獨立的操作做這一點 - 但我不能如果要花費b^2的記憶或時間,當這種操作在b中顯然應該是線性的時候,不能承擔浪費。這似乎是點式乘法(包括廣播?)和張量收縮(矩陣乘以公共m維度)的一些組合,但我不能把它縮小。

如果有人能真正說服我,在tensorflow提供的操作中O(b) flops是不可能的,那麼沒關係,但是我肯定想要一個O(b^2)

更新:它看起來像適當修改A張量可以使用tf.gather_nd單獨構建;如果這可以與B以某種方式配對,也許?不幸的是,我迄今爲止的實驗導致在tf.gather_nd本身發現了一個錯誤,這已經減慢了速度。

回答

1

我想出瞭如何做到這一點,合理高效。首先建立的B修改後的版本與tf.gather,在第一索引中的相應部分:

B2 = tf.gather(B, C) 

然後拉出只是用tf.gather_ndA張量的相關部分。我們將推出一系列[0,C[0]], [1,C[1]], [2,C[2]]...等索引對,因此首先我們需要構建索引張量。

a = tf.shape(A)[0] 
A2_indices = tf.stack([tf.range(a), C], axis=0) 
A2 = tf.gather_nd(A, A2_indices) 

具有形狀[a,c,d]產生A2。現在我們需要適當地乘以A2B2。這是m指數(分別爲2和3)中的張量收縮,但在i指數(兩者均爲0)中是逐點乘法。這意味着,可悲的是,結果項不是張量收縮或逐點乘法!一種選擇是計算張量積,然後只在m之上收縮,然後在i兩個指數上取tf.diag - 但這會浪費大量計算來構建我們不需要的矩陣的其餘部分。相反,我們可以將其視爲成批矩陣乘法:這通常稱爲​​,但現在它只是matmul。儘管如此,有一點需要注意,除了每個輸入張量中的2個矩陣維外,其餘的都必須逐點相乘。 BB2不符合此標準,因爲它們具有附加的j索引。但是,我們可以用l輸出維度「包裝」,然後再刪除它。這意味着首先調用tf.transposejl彼此相鄰,然後tf.reshape變成一個j*l輸出尺寸,然後做tf.matmul,然後又tf.reshapetf.transpose返回到原來的形式。所以

a, b, d, e = B2.get_shape().as_list() 
B2_trans = tf.transpose(B2, perm=[0,2,1,3]) 
B2_jl = tf.reshape(B2, [a,d,b*e]) 
product_jl = tf.matmul(A2, B2_jl) 
product_trans = tf.reshape(product_jl, [a,d,b,e]) 
result = tf.transpose(product_trans, perm=[0,2,1,3]) 

其中完成吧!當然在實踐中很可能只有B在這一個實例中需要,在這種情況下,B可能已經開始處於「壓縮」狀態,節省了轉置(以及便宜的重塑)。或者如果A2將被平鋪或轉置,那麼它也可以節省轉置。但總體而言,一切都非常複雜。 :)