2017-10-08 280 views
0

我想數據增強功能添加到TensorFlow MNIS例如mnist_deep.py使用tf.contrib.image.rotate()如何將標量張量轉換爲TensorFlow模型中的標量?

rotate_angle = 0.1 


def deepnn(x): 
    ... 
    with tf.name_scope('rotate'): 
     angle = tf.tf.placeholder(tf.float32) 
     x_image = tf.contrib.image.rotate(x_image, angle) # Wrong! 
    ... 
    return angle 


... 
angle = deepnn(x) 
with tf.Session() as sess: 
    angle.eval({angle: rotate_angle} 

這不起作用,因爲tf.contrib.image.rotate()只接受簡單的標量的角度。

我試過TensorFlow: cast a float64 tensor to float32但遺憾的是現在提到的函數還返回一個張量。

我應該如何將張量標量轉換爲模型本身的標量?我想重複使用相同的模型,併爲培訓和測試提供不同的角度。

回答

0

我不認爲你需要奇怪的轉換,但一些代碼的重新組織。我發現了一個可能的解決您的問題,我希望它是適合你:

import tensorflow as tf 
import numpy as np 

rotate_angle = 0.1 

def deepnn(x,angle): 
    x_image = tf.contrib.image.rotate(x, angle)  
    return x_image 

angle = tf.placeholder(tf.float32,shape=()) 
input_image_placeholder = tf.placeholder(tf.float32,shape=(100,100,3)) 


rotated_x_image = deepnn(input_image_placeholder,angle) 


sess = tf.Session() 

input_image = np.ones(dtype=float,shape=(100,100,3)) 

curr_rotated_x_image = sess.run(rotated_x_image,{angle:rotate_angle,input_image_placeholder:input_image}) 

print(curr_rotated_x_image) 

sess.close() 

我不認爲聲明的佔位符裏面的函數是一個好主意,所以我外移它。讓我知道這個解決方案是否可以!