我有訓練有素的深層神經網絡,帶有句級關注層。如下所示,網絡被稱爲GRU,我想在測試後得到關注值(sen_alpha)的結果。無法在TensorFlow中以.py格式保存numpy,並將python作爲cPickle.PicklingError
class GRU:
def __init__(self,is_training,word_embeddings,settings):
self.big_num = big_num = settings.big_num
for i in range(big_num):
sen_repre.append(tf.tanh(attention_r[self.total_shape[i]:self.total_shape[i+1]]))
batch_size = self.total_shape[i+1]-self.total_shape[i]
sen_alpha.append(tf.reshape(tf.nn.softmax(tf.reshape(tf.matmul(tf.mul(sen_repre[i],sen_a),sen_r),[batch_size])),[1,batch_size]))
self.attentions.append(sen_alpha[i])
測試代碼:
def main(_):
test_settings = Settings()
with tf.Graph().as_default():
sess = tf.Session()
with sess.as_default():
with tf.variable_scope("model"):
mtest = GRU(is_training=False, word_embeddings = None, settings = test_settings)
saver = tf.train.Saver()
attentions = mtest.attentions
att = np.array(attentions)
print(str(type(att)))
print(att[0:100])
np.save("attentions.npy",att)
結果:
類型:類型numpy.ndarray'
ATT [0:100]:
[<tf.Tensor 'model/Reshape_9:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_17:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_25:0' shape=(1, ?) dtype=float32>
錯誤:
文件 「test_GRU.py」,線路242,在主 np.save( 「attentions.npy」,ATT)
cPickle.PicklingError:不能鹹菜:屬性查找內置 .module失敗
如何正確保存結果?由於
爲什麼你甚至使用'tf.Tensor'對象的'numpy'對象數組?這沒什麼意義。只是我們的列表或其他東西。 –
'att'可能是一個對象dtype數組。也就是說,一個包含一個或多個指向'關注'對象的指針的數組。 'np.save'使用'pickle'來保存對象。雖然它可以直接向文件寫入一個數字數據庫,但它必須使用'pickle'來創建一個字節字符串。我的猜測是'tf.Tensor'沒有定義酸洗方法。檢查Tensorflow是否有自己定義的保存方法。 – hpaulj