2017-08-02 42 views
0

我有訓練有素的深層神經網絡,帶有句級關注層。如下所示,網絡被稱爲GRU,我想在測試後得到關注值(sen_alpha)的結果。無法在TensorFlow中以.py格式保存numpy,並將python作爲cPickle.PicklingError

class GRU: 
def __init__(self,is_training,word_embeddings,settings): 

    self.big_num = big_num = settings.big_num  
    for i in range(big_num): 

     sen_repre.append(tf.tanh(attention_r[self.total_shape[i]:self.total_shape[i+1]])) 
     batch_size = self.total_shape[i+1]-self.total_shape[i] 
       sen_alpha.append(tf.reshape(tf.nn.softmax(tf.reshape(tf.matmul(tf.mul(sen_repre[i],sen_a),sen_r),[batch_size])),[1,batch_size])) 
       self.attentions.append(sen_alpha[i]) 

測試代碼:

def main(_): 
test_settings = Settings() 
with tf.Graph().as_default(): 

    sess = tf.Session() 
    with sess.as_default():  
     with tf.variable_scope("model"): 
          mtest = GRU(is_training=False, word_embeddings = None, settings = test_settings) 
        saver = tf.train.Saver() 

      attentions = mtest.attentions 
      att = np.array(attentions)  
      print(str(type(att))) 
      print(att[0:100]) 
      np.save("attentions.npy",att) 

結果:

類型:類型numpy.ndarray'

ATT [0:100]:

[<tf.Tensor 'model/Reshape_9:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_17:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_25:0' shape=(1, ?) dtype=float32> 錯誤:

文件 「test_GRU.py」,線路242,在主 np.save( 「attentions.npy」,ATT)

cPickle.PicklingError:不能鹹菜:屬性查找內置 .module失敗

如何正確保存結果?由於

+0

爲什麼你甚至使用'tf.Tensor'對象的'numpy'對象數組?這沒什麼意義。只是我們的列表或其他東西。 –

+0

'att'可能是一個對象dtype數組。也就是說,一個包含一個或多個指向'關注'對象的指針的數組。 'np.save'使用'pickle'來保存對象。雖然它可以直接向文件寫入一個數字數據庫,但它必須使用'pickle'來創建一個字節字符串。我的猜測是'tf.Tensor'沒有定義酸洗方法。檢查Tensorflow是否有自己定義的保存方法。 – hpaulj

回答

0

我無法修復你的代碼,但我可以給你一步一步從設計模型定義的短版從中提取值:

  1. 定義模型圖。這意味着GRU是該圖的一部分。
  2. 開始會話,例如, sess = tf.Session()
  3. 初始化圖的變量,例如, sess.run(tf.global_variables_initializer())
  4. 使用會話方法從相應的圖表中獲取值,例如, sess.run(the_tensor, dictionary_of_numpy_array_as_input_to_graph)

輸出將是numpy數組,您可以保存它們。

相關問題