2017-07-03 69 views
3

我看了docs of tf.one_hot,發現TensorFlow,tf.one_hot爲什麼輸出的形狀是由axis的值定義的?

...。新軸在尺寸軸上創建(默認:新軸在最後附加)。

什麼是The new axis

如果索引是的長度特徵的載體,其輸出的形狀將是:

設有×深如果軸== -1

深×特徵如果軸== 0

如果指數是具有形狀[批次,特徵]的矩陣(批次),則輸出形狀將爲:

批次X功能x深如果軸== -1

批次×深×特徵如果軸== 1

深×批次X特徵如果軸== 0

爲什麼輸出的形狀是由軸定義的?

回答

6

tf.one_hot()轉換索引列表(例如[0, 2, 1])並將其轉換成長度爲depth的單向矢量列表。

例如,如果depth = 3,在輸入

  • 索引0將由[1,0,0]在輸入
  • 索引1將被替換被替換由[0,1,0 ]
  • 索引2在輸入將由[0,0,1]

所以[0, 2, 1]將被編碼爲[[1, 0, 0], [0, 0, 1], [0, 1, 0]]

0來替換

如您所見,輸出具有比輸入更多的維度(因爲每個索引都被一個向量所替代)。

默認情況下(和你通常需要)新的層面創建作爲最後一個,所以如果你的輸入形狀(d1, d2, .., dn)的,你的輸出就會形狀(d1, d2, .., dn, depth)的。但是,如果更改輸入參數,則可以選擇將其他位置的維度放在其他位置,例如,如果axis=0的輸出形狀爲(depth, d1, d2, .., dn)

更改維度的順序基本上是n維版本的轉置:您具有相同的數據,但切換索引的順序以訪問它們(等同於切換2D矩陣中的列和行) 。

+0

由於'如果輸入索引是等級N,輸出將具有等級N + 1,軸的值是輸出深度的索引 –