2017-01-10 43 views
0

當在圖構建階段時,假設張量x這是一個神經網絡的完全連接層構建第一軸形狀的條件圖在張量流中爲「無」

因此假設x的形狀是(?, 5)。我想設置的最後一列像這樣的Python:

for i in range(x.shape[0]): 
    if x[i,-1] < 0.5: 
     x[i,-1] = 0.0 
    else: 
     x[i,-1] = 1.0 

我只能用tf.condx只有1行是這樣的:

# const3 and const4 are constant mask 
out = tf.cond(tf.greater(out[-1], tf.constant(0.5)), 
       lambda: tf.add(tf.multiply(out, const3), const4), 
       lambda: tf.multiply(out, const3)) 

如何檢查的x時,第一個維度以上是?

回答

2

是否這樣?

import tensorflow as tf 
import numpy as np 

a = tf.placeholder(tf.int32, shape=[None, 5]) 

r, c = a.get_shape() 
x_split = tf.split(1, c, a) # split a along axis 1 
last_col = x_split[-1] 


mask = tf.greater(last_col, tf.constant(6)) 
cond = tf.where(mask, 
       tf.add(last_col, tf.constant(1)), # if true, add one 
       tf.add(last_col, tf.constant(-1)))# if false, minus one 

x_split = x_split[0:-1] 
x_split.append(cond) 
ans = tf.concat(1, x_split) 


with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    arr = np.array([[1, 2, 3, 4, 5], 
        [6, 7, 8, 9, 10]]) 
    the_ans = sess.run(ans, feed_dict={a: arr}) 
''' 
the_ans is 
[[ 1 2 3 4 4] 
[ 6 7 8 9 11]] 
''' 
+0

非常感謝您的回答,這就是我所尋找的:-) –

相關問題