2017-06-05 102 views

Tensorflow: computing gradients w.r.t. a sub-tensor

Let v be a tensor. If I compute the gradient of another tensor with respect to v, everything works as expected, i.e.

grads = tf.gradients(loss_func, v) 

runs without problems.

However, when I want to compute the gradient with respect to just a single element of v, or any sub-tensor of v, I get an error: both

grads = tf.gradients(loss_func, v[0,0]) 
grads = tf.gradients(loss_func, v[:,1:]) 

produce the following traceback:

Traceback (most recent call last): 
  File "<stdin>", line 1, in <module> 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 866, in runfile 
    execfile(filename, namespace) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile 
    exec(compile(f.read(), filename, 'exec'), namespace) 
  File "/Users/henning/pflow/testing.py", line 89, in <module> 
    theta = sess.run(grads, feed_dict={P:P_inp, Q:Q_inp}) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 767, in run 
    run_metadata_ptr) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 952, in _run 
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 408, in __init__ 
    self._fetch_mapper = _FetchMapper.for_fetch(fetches) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch 
    return _ListFetchMapper(fetch) 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 337, in __init__ 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 337, in <listcomp> 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 227, in for_fetch 
    (fetch, type(fetch))) 
TypeError: Fetch argument None has invalid type <class 'NoneType'> 
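The `Fetch argument None` comes from the gradient list itself: `tf.gradients` returns `[None]` whenever the target is not on the path from the loss back through the graph, and `v[0,0]` (or `v[:,1:]`) is a *new* slicing op computed *from* `v`, so the loss does not depend on it. A minimal sketch demonstrating this (graph mode; the variable's shape and the loss are illustrative, not from the original code):

```python
# In TF 1.x a plain `import tensorflow as tf` is enough;
# the compat import below also works under TF 2.x.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.Variable(tf.ones([2, 2]))
loss_func = tf.reduce_sum(v * v)

print(tf.gradients(loss_func, v))        # a real gradient tensor
print(tf.gradients(loss_func, v[0, 0]))  # [None]: v[0, 0] is downstream of v,
                                         # so loss_func does not depend on it
```

Passing that `[None]` to `session.run` is what raises the `TypeError` in the traceback.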

What am I doing wrong?

Answer


I found a solution to the problem.

The most elegant way I found is to 'assemble' v from a constant part and a variable part, and then compute the gradient with respect to the variable only, i.e.

v_free = tf.Variable(initial_free_values)   # tf.Variable takes an initial value, not a shape 
v_notfree = tf.constant(fixed_values)       # the entries that should stay fixed 
v = tf.concat([v_notfree, v_free], axis=0)  # tf.concat requires an axis argument 
loss_func = ...                             # some function of v 
grads = tf.gradients(loss_func, v_free) 
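A fleshed-out version of this sketch, assuming for illustration a 3x2 tensor whose first row is fixed and whose remaining two rows are trainable (the shapes, values, and sum-of-squares loss are placeholders, not from the original question):

```python
# Graph-mode API; in TF 1.x, `import tensorflow as tf` suffices.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v_notfree = tf.constant([[1.0, 2.0]])       # fixed row: no gradient flows here
v_free = tf.Variable(tf.zeros([2, 2]))      # trainable rows
v = tf.concat([v_notfree, v_free], axis=0)  # full 3x2 tensor

loss_func = tf.reduce_sum(v * v)            # placeholder for "some function of v"
grads = tf.gradients(loss_func, v_free)     # well-defined 2x2 gradient

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads))                  # gradient w.r.t. the free rows only
```

Because the loss is now genuinely a function of `v_free` (rather than of a slice created downstream of `v`), `tf.gradients` returns a proper tensor instead of `None`, and `sess.run` no longer fails.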

May I ask you (unrelated to this topic) - does "w.r.t." mean something? I have now run into this abbreviation for the n-th time in machine-learning topics, and people here use it a lot, so I started to wonder – user3613919