2017-03-01 102 views
3

在Tensorflow source,我明白了,Tensorflow操作文檔

REGISTER_OP("BroadcastGradientArgs") 
    .Input("s0: T") 
    .Input("s1: T") 
    .Output("r0: T") 
    .Output("r1: T") 
    .Attr("T: {int32, int64} = DT_INT32") 
    .SetShapeFn([](InferenceContext* c) { 
     ... uninteresting details ... 
    }) 
    .Doc(R"doc(                            
Return the reduction indices for computing gradients of s0 op s1 with broadcast.            

This is typically used by gradient computations for a broadcasting operation.             
)doc"); 

在Python中,我能做到以下幾點,

>>> from tensorflow.python.ops import gen_array_ops 
>>> gen_array_ops._InitOpDefLibrary()._ops['BroadcastGradientArgs'].op_def 
name: "BroadcastGradientArgs" 
input_arg { 
    name: "s0" 
    type_attr: "T" 
} 
... more stuff ... 
attr { 
    name: "T" 
    type: "type" 
    ... uninteresting details ... 
} 

請注意,我在Python我得到的Protobuf定義(我刪除一些爲簡潔起見)的TF操作。我想獲取我在C++代碼中看到的定義的文檔部分。我如何得到它?

+0

TF源代碼中存在信息(.Doc),但它沒有通過python包裝導出,因此在Python中不可用。看[這裏](https://github.com/tensorflow/tensorflow/blob/a3e636c0f561e2ac6d9f8a0044fbe09acb003803/tensorflow/python/framework/python_op_gen.cc),這是你在python中使用的'op_def'生成的地方。一些代碼需要在這裏實現,所以你可以在Python中訪問這些文檔。 – Arash

+0

謝謝。看來https://github.com/tensorflow/tensorflow/blob/a3e636c0f561e2ac6d9f8a0044fbe09acb003803/tensorflow/core/framework/op_def_util.cc#L684是從OpDef中刪除描述的地方。這是真的?如果我刪除該行https://github.com/tensorflow/tensorflow/blob/a3e636c0f561e2ac6d9f8a0044fbe09acb003803/tensorflow/python/framework/python_op_gen.cc#L708,那麼OpDef的文檔仍然存在於Protobuf中? –

+0

對該行解除註釋會導致protobuf出現一系列問題。也許有更好的辦法。 https://github.com/google/protobuf/issues/2798 –

回答

1

這很痛苦。你需要的補丁TF和的Protobuf

https://github.com/tensorflow/tensorflow/issues/8207 https://github.com/google/protobuf/issues/2798

然後您還需要註釋掉線, https://github.com/tensorflow/tensorflow/blob/a3e636c0f561e2ac6d9f8a0044fbe09acb003803/tensorflow/python/framework/python_op_gen.cc#L708

重建並運行了,

bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package && \ 
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg && \ 
sudo pip install --upgrade /tmp/tensorflow_pkg/tensorflow-1.0.0*.whl 

和執行我的測試,

$ python -c "from tensorflow.python.ops import gen_array_ops; print gen_array_ops._InitOpDefLibrary()._ops['BroadcastGradientArgs'].op_def" 
name: "BroadcastGradientArgs" 
input_arg { 
    name: "s0" 
    type_attr: "T" 
} 
input_arg { 
    name: "s1" 
    type_attr: "T" 
} 
output_arg { 
    name: "r0" 
    type_attr: "T" 
} 
output_arg { 
    name: "r1" 
    type_attr: "T" 
} 
attr { 
    name: "T" 
    type: "type" 
    default_value { 
    type: DT_INT32 
    } 
    allowed_values { 
    list { 
     type: DT_INT32 
     type: DT_INT64 
    } 
    } 
} 
summary: "Return the reduction indices for computing gradients of s0 op s1 with broadcast." 
description: "This is typically used by gradient computations for a broadcasting operation." 

這個問題太多了正則表達式...