2017-06-15 88 views
1

我創建了一個稱爲CustomFunc的自定義功能,說明這裏下面:https://www.cntk.ai/pythondocs/extend.html如何編寫自定義函數CNTK

如果我使用它的文章的建議,它的工作原理:

model = cntk.user_function(CustomFunc(prev_node)) 

這個作品很好,模型運行沒有任何問題。我的問題是,我想在cntk.layers.Sequential調用中使用此函數,並在cntk.layers.Recurrence調用中使用此函數。要做到這一點,我需要以另一種方式構建函數的組合,然後將其放入Sequential或Recurrence調用中。現在我使用一些佔位符,即我做的是:

customFunToUse = cntk.user_function(CustomFunc(cntk.placeholder(), otherInputs)) 
model = cntk.layers.Sequential([cntk.layers.Dense(100), 
           customFunToUse, 
           cntk.layers.Recurrence(
           customFunToUse >> cntk.layers.LSTM(100))]) 

但是,這並不工作,並提出了各種錯誤:有時它是一個段錯誤,在其他類似型號是

"ValueError: Cannot create an NDArrayView using a view shape '[? x 10]' that has unknown dimensions for any of its axes." 
而不是

其他時間是

Evaluate: All nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes ... 

還要注意的是我的自定義功能不改變輸入尺寸:給予paramters的任何金額,它會返回相同的數量和類型。該代碼是這樣的:

class CustomFun(UserFunction): 
    def __init__(self, *args, otherStuff, name='CustomFun'): 
     super(CustomFun, self).__init__(list(args), name=name) 
     self.otherStuff = otherStuff 

    def forward(self, arguments, outputs=None, keep_for_backward=None, device=None, as_numpy=True): 
     return None,[x/2 for x in arguments] 

    def backward(self, state, root_gradients, variables=None, as_numpy=True): 
     #it's not important right now, just a test... 
     return root_gradient 

    def infer_outputs(self): 
     #shape, type and dynamic axes of inputs are not changed by this function 

     outputVar = [output_variable(self.inputs[idx].shape, self.inputs[idx].dtype, 
      self.inputs[idx].dynamic_axes, name='out_quantLayer') for idx in range(len(self.inputs))] 
     return outputVar 

    def serialize(self): 
     return {'otherStuff': self.otherStuff} 

    @staticmethod 
    def deserialize(inputs, name, state): 
     return CustomFun(inputs, otherStuff=state['otherStuff'], name=name) 

回答

1

正確的方法是寫這樣的 def my_layer(x): @C.Function def apply(x): return cntk.user_function(CustomFunc(x)) return apply 不幸的是這似乎導致我的Python解釋器崩潰。我已經在此打開github issue 2132。問題得到解決後,將嘗試更新此答案。

更新:有一個我們沒有捕捉到的小錯字。在github問題頁面有一個解決方案。