4

我想在某些時間序列數據上運行GRU單元,以根據最後一層中的激活對它們進行聚類。我做了一個小改動,GRU單元實現當獲取可變序列長度的激活時,Tensorflow GRU單元格錯誤

def __call__(self, inputs, state, scope=None): 
"""Gated recurrent unit (GRU) with nunits cells.""" 
with vs.variable_scope(scope or type(self).__name__): # "GRUCell" 
    with vs.variable_scope("Gates"): # Reset gate and update gate. 
    # We start with bias of 1.0 to not reset and not update. 
    r, u = array_ops.split(1, 2, linear([inputs, state], 2 * self._num_units, True, 1.0)) 
    r, u = sigmoid(r), sigmoid(u) 
    with vs.variable_scope("Candidate"): 
    c = tanh(linear([inputs, r * state], self._num_units, True)) 
    new_h = u * state + (1 - u) * c 

    # store the activations, everything else is the same 
    self.activations = [r,u,c] 
return new_h, new_h 

這之後我串接在以下方式激活之前,我在調用該GRU細胞

@property 
def activations(self): 
    return self._activations 


@activations.setter 
def activations(self, activations_array): 
    print "PRINT THIS"   
    concactivations = tf.concat(concat_dim=0, values=activations_array, name='concat_activations') 
    self._activations = tf.reshape(tensor=concactivations, shape=[-1], name='flatten_activations') 

劇本歸還我調用GRU以下面的方式

outputs, state = rnn.rnn(cell=cell, inputs=x, initial_state=initial_state, sequence_length=s) 

哪裏s細胞是批次長度的陣列與時間戳的輸入批量的每個元件的數量。

最後,我提取採用了

fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict) 

在執行我收到以下錯誤

回溯(最近通話最後一個): 文件 「xxx.py」,線路162,在 取= sess.run(fetches = cell.activations,feed_dict = feed_dict) 文件「/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py」,行315,運行中 return self._run(None,fetches,feed_dict) 文件「/xxx/local/lib/python2.7/site-pac kages/tensorflow/python/client/session.py「,第511行,在_run feed_dict_string) 文件」/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py「,第564行,在_do_run target_list中) 文件「/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py」,第588行,在_do_call中 six.reraise(e_type,e_value ,e_traceback) 文件「/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py」,第571行,在_do_call中 return fn(* args) File「/ xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py「,第555行,在_run_fn中

return tf_sessi on.TF_Run(會話,feed_dict,fetch_list,target_list) tensorflow.python.pywrap_tensorflow.StatusNotOK:無效參數:爲RNN/cond_396/ClusterableGRUCell/flatten_activations返回的張量:0無效。

有人可以告訴我們如何在最後一步從GRU單元中獲取激活,並傳遞可變長度序列嗎?謝謝。

回答

0

要從最後一步獲取激活,您想使激活成爲您的狀態的一部分,這是tf.rnn返回的。

相關問題