2017-06-06 110 views
0

我試圖用TensorFlow矩陣顯示自定義消息錯誤,當矩陣具有等於0的行列式時,逆算法無法計算但無法實現以顯示消息錯誤功能。我的代碼的結構如下所示:如何使用Tensorflow自定義錯誤消息

import tensorflow as tf 
def inversematricx(arg): 
    args = tf.convert_to_tensor(arg, dtype=tf.float32) 
    try: 
     return tf.matrix_inverse(args) 
    except: 
     raise ValueError("Determinant is 0. Input is not invertible") 

mat1=tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) # Determinant is 0 for mat1 
mat2=tf.constant([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 

inverse=inversematricx(mat1) 

with tf.Session() as sess: 

    result = sess.run(inverse)  
    print(result) 

結果MAT2

[[0.17647055 -0.82352936 0.47058824] [-0.88235289 1.11764693 -0.35294116] [0.64705878 -0.35294113 0.05882351]]

但對於mat1,其行列式等於0,我想強制輸出 作爲ValueError消息,而不是產生的錯誤:

InvalidArgumentError: Input is not invertible. 
    [[Node: MatrixInverse_21 = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Const_69)]] 

Caused by op 'MatrixInverse_21', defined at: 
    File "D:\WinPython-64bit-3.5.3.0Qt5\python-3.5.3.amd64\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 227, in <module> 
main() 
.... 
InvalidArgumentError (see above for traceback): Input is not invertible. 
[[Node: MatrixInverse_21 = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Const_69)]] 
+0

一個方法是使用'tf.cond'(https://www.tensorflow.org/api_docs/python/tf/cond)來檢查行列式是否非零,並且用'tf.Print'(https://www.tensorflow.org/api_docs/python/tf/Print)打印信息 –

+0

目前,它不是可能會捕捉圖例外。參考[this](https://github.com/tensorflow/tensorflow/issues/10332) – frankyjuang

回答

0

我找到了解決方案與tf.Print像如下自定義函數:

sess = tf.InteractiveSession() 
def checkMatrixInverse(arg): 
    f=tf.matrix_determinant(arg).eval() #get determinant value 
    args = tf.convert_to_tensor(arg, dtype=tf.float32) 
    inv=tf.matrix_inverse(args) 
    err='Input is not invertible:'  
    if(f==0): 
     return tf.Print(err,[err], name="NotInvertible") 
    else: 
     return tf.Print(inv, [inv], name="Inverse") 

noninverse=checkMatrixInverse(mat1) #output b'Input is not invertible:' 

inverse=checkMatrixInverse(mat2) 
#output: 
#[[ 0.17647055 -0.82352936 0.47058824] 
# [-0.88235289 1.11764693 -0.35294116] 
# [ 0.64705878 -0.35294113 0.05882351]]