2016-11-04 72 views
0

我納悶這WRT topicTheano功能相當於Tensorflow

我想解決更新的Theano.function問題與此懶tensorflow CONSTRUTION:

class TensorFlowTheanoFunction(object): 
def __init__(self, inputs, outputs, session): 
    self._inputs = inputs 
    self._outputs = outputs 
    self.session = session 

def __call__(self, *args, **kwargs): 
    feeds = {} 
    for (argpos, arg) in enumerate(args): 
     feeds[self._inputs[argpos]] = arg 
    return self.session.run(self._outputs, feeds) 

如果我想要通過更新參數(如在Theano中)如何修改此懶惰呼叫? 我只是想,這也可以在tensorflow工作:

self.new = theano.function([], [], updates=zip(old_params, params)) 

回答

1

只需修改從該線程雅羅斯拉夫的代碼中使用tf.assign,有控制的依賴,以確保輸出計算的分配發生之前:

import tensorflow as tf 

class TensorFlowTheanoFunction(object): 
    def __init__(self, inputs, outputs, updates=()): 
    self._inputs = inputs 
    self._outputs = outputs 
    self._updates = updates 

    def __call__(self, *args, **kwargs): 
    feeds = {} 
    for (argpos, arg) in enumerate(args): 
     feeds[self._inputs[argpos]] = arg 
    try: 
     outputs_identity = [tf.identity(output) for output in self._outputs] 
     output_is_list = True 
    except TypeError: 
     outputs_identity = [tf.identity(self._outputs)] 
     output_is_list = False 
    with tf.control_dependencies(outputs_identity): 
     assign_ops = [tf.assign(variable, replacement) 
        for variable, replacement in self._updates] 
    outputs_list = tf.get_default_session().run(
     outputs_identity + assign_ops, feeds)[:len(outputs_identity)] 
    if output_is_list: 
     return outputs_list 
    else: 
     assert len(outputs_list) == 1 
     return outputs_list[0] 

a = tf.placeholder(dtype=tf.int32) 
b = tf.placeholder(dtype=tf.int32) 
variable = tf.get_variable(
    "variable", shape=[], dtype=tf.int32, initializer=tf.zeros_initializer) 
c = a + b + variable 
d = a - b 
sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
f = TensorFlowTheanoFunction([a, b], [c, d], updates=[(variable, variable + 1)]) 
print f(1, 2) 
print f(1, 2) 
print f(0, 2) 
f = TensorFlowTheanoFunction([a, b], c, updates=[(variable, variable + 1)]) 
print f(1, 2) 
print f(1, 2) 
print f(0, 2) 

這將更新在每次迭代變量:

[3, -1] 
[4, -1] 
[4, -2] 
6 
7 
7 
+0

HM,我有一個錯誤與我以前的代碼WRT這個新功能實現: **提高類型錯誤( 「 '張量' 對象不是可迭代。」)** 一些堆棧: 'prev_call> VAL = TensorFlowTheanoFunction([self.input],self.get_output(),會話)' 'err_call> outputs_identity = [tf.identity(output)for self in.outputs]' – Glau

+1

啊,對不起。當輸出是標量時,我添加了一個特例。 –

+0

對不起。在我看來,它在這些更正後有效,但速度稍慢(100-200倍)。無論如何,thx求助! – Glau