2017-02-11 82 views
19

這是我的previous question的後續行動,我問道爲什麼流合併不在某個程序中踢。事實證明,問題在於某些函數未被內聯,並且一個標誌將性能提高了約17x(它展示了內聯的重要性!)。有什麼方法可以內聯遞歸函數嗎?

現在,請注意,在原始問題上,我一次硬編碼64調用incAll。現在,假設,相反,我創建一個nTimes功能,反覆調用的函數:

module Main where 

import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a 
nTimes 0 f x = x 
nTimes n f x = f (nTimes (n-1) f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes 64 incAll array) 

在這種情況下,只需添加一個INLINE編譯到nTimes不會幫助,因爲據我所知GHC不內聯遞歸功能。在編譯時強制GHC擴展nTimes是否有任何竅門,從而恢復預期的性能?

+2

您可以使用Template Haskell來引入語法來擴展重複的應用程序。 –

+1

@JoachimBreitner剛剛完成了這個。必須學習模板Haskell。仍在測試我的答案,但似乎要快得多(類似於其他問題)。 – Zeta

回答

26

不,但您可以使用更好的功能。我不是在談論V.map (+64),這會讓事情變得更快,但約nTimes。我們有三個候選人已經做nTimes做:

{-# INLINE nTimesFoldr #-} 
nTimesFoldr :: Int -> (a -> a) -> a -> a  
nTimesFoldr n f x = foldr (.) id (replicate n f) $ x 

{-# INLINE nTimesIterate #-} 
nTimesIterate :: Int -> (a -> a) -> a -> a  
nTimesIterate n f x = iterate f x !! n 

{-# INLINE nTimesTail #-} 
nTimesTail :: Int -> (a -> a) -> a -> a  
nTimesTail n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

所有版本大約需要8秒,相比40秒鐘內你的版本需要。順便說一下,Joachim的版本也需要8秒。請注意,iterate版本在我的系統上佔用更多內存。儘管GHC有unroll plugin,但在過去的五年裏它沒有更新(它使用自定義的說明)。

根本沒有展開?

但是,在我們絕望之前,GHC實際上試圖將所有內容都嵌入其中?讓我們用nTimesTailnTimes 1

module Main where 
import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a  
nTimes n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes 1 incAll array) 
$ stack ghc --package vector -- -O2 -ddump-simpl -dsuppress-all SO.hs 
main2 = 
    case (runSTRep main3) `cast` ... 
    of _ { Vector ww1_s9vw ww2_s9vx ww3_s9vy -> 
    case $wgo 1 ww1_s9vw ww2_s9vx ww3_s9vy 
    of _ { (# ww5_s9w3, ww6_s9w4, ww7_s9w5 #) -> 

我們可以停在那兒。 $wgo是上面定義的go。即使使用1 GHC也不會展開循環。這是令人不安的,因爲1是一個常數。

拯救模板

但是,唉,它並不是全部。如果C++程序員能夠對編譯時常量進行以下操作,那麼我們應該如此,對吧?

template <int N> 
struct Call{ 
    template <class F, class T> 
    static T call(F f, T && t){ 
     return f(Call<N-1>::call(f,std::forward<T>(t))); 
    } 
}; 
template <> 
struct Call<0>{ 
    template <class F, class T> 
    static T call(F f, T && t){ 
     return t; 
    } 
}; 

果然,我們可以與TemplateHaskell*

-- Times.sh 
{-# LANGUAGE TemplateHaskell #-} 
module Times where 

import Control.Monad (when) 
import Language.Haskell.TH 

nTimesTH :: Int -> Q Exp 
nTimesTH n = do 
    f <- newName "f" 
    x <- newName "x" 

    when (n <= 0) (reportWarning "nTimesTH: argument non-positive") 

    let go k | k <= 0 = VarE x 
     go k   = AppE (VarE f) (go (k - 1)) 
    return $ LamE [VarP f,VarP x] (go n) 

是什麼nTimesTH辦?它會創建一個新函數,其中第一個名稱f將應用於第二個名稱x,總計n次。 n現在需要一個編譯時間常數,它適合我們,因爲循環展開,纔可能與編譯時間常數:

$(nTimesTH 0) = \f x -> x 
$(nTimesTH 1) = \f x -> f x 
$(nTimesTH 2) = \f x -> f (f x) 
$(nTimesTH 3) = \f x -> f (f (f x)) 
... 

是否行得通?它快嗎?與nTimes相比有多快?讓我們嘗試另一個main爲:

-- SO.hs 
{-# LANGUAGE TemplateHaskell #-} 
module Main where 
import Times 
import qualified Data.Vector.Unboxed as V 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a  
nTimes n f = go n 
    where 
    {-# INLINE go #-} 
    go n x | n <= 0 = x 
    go n x   = go (n - 1) (f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    let vTH = V.sum ($(nTimesTH 64) incAll array) 
    let vNorm = V.sum (nTimes 64 incAll array) 
    print $ vTH == vNorm 
stack ghc --package vector -- -O2 SO.hs && SO.exe +RTS -t 
True 
<<ghc: 52000056768 bytes, 66 GCs, 400034700/800026736 avg/max bytes residency (2 samples), 1527M in use, 0.000 INIT (0.000 elapsed), 8.875 MUT (9.119 elapsed), 0.000 GC (0.094 elapsed) :ghc>> 

它得到正確的結果。它有多快?讓我們再次用另一個main

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum ($(nTimesTH 64) incAll array) 
 800,048,112 bytes allocated in the heap           
      4,352 bytes copied during GC            
      42,664 bytes maximum residency (1 sample(s))        
      18,776 bytes maximum slop             
      764 MB total memory in use (0 MB lost due to fragmentation)    

            Tot time (elapsed) Avg pause Max pause   
    Gen 0   1 colls,  0 par 0.000s 0.000s  0.0000s 0.0000s   
    Gen 1   1 colls,  0 par 0.000s 0.049s  0.0488s 0.0488s   

    INIT time 0.000s ( 0.000s elapsed)           
    MUT  time 0.172s ( 0.221s elapsed)           
    GC  time 0.000s ( 0.049s elapsed)           
    EXIT time 0.000s ( 0.049s elapsed)           
    Total time 0.188s ( 0.319s elapsed)           

    %GC  time  0.0% (15.3% elapsed)           

    Alloc rate 4,654,825,378 bytes per MUT second         

    Productivity 100.0% of total user, 58.7% of total elapsed   

好,比較,爲8秒。因此,對於TL; DR:如果您有編譯時常量,並且您想基於該常量創建和/或修改您的代碼,請考慮模板Haskell。

*請注意,這是我寫的第一個模板Haskell代碼。小心使用。不要使用太大的n,否則你最終可能會遇到混亂的功能。

+2

注意:解決方案是[代碼審查](https://codereview.stackexchange.com/questions/155144/execute-a-function-n-times-where-n-is-known-at-compile-time )。 – Zeta

+0

嘿剛回來讓你知道這是在大多數方面的輝煌答案,謝謝。 – MaiaVictor

4

你可以寫

{-# INLINE nTimes #-} 
nTimes :: Int -> (a -> a) -> a -> a 
nTimes n f x = go n 
    where go 0 = x 
     go n = f (go (n-1)) 

和GHC會內聯nTimes,並有可能專門遞歸go您的特定參數incAllarray,但它不會展開循環。

+0

啊,很爛,謝謝。 – MaiaVictor

14

Andres已經告訴過我一個小知道的技巧,在那裏你可以通過使用類型類實際獲得GHC內聯遞歸函數。

這個想法是,而不是寫一個函數,通常你在一個值上執行結構遞歸。您可以使用類型類定義函數,並對類型參數執行結構遞歸。在這個例子中,類型級自然數。

由於每次遞歸調用的類型不同,GHC會高興地嵌入每個遞歸調用並生成高效的代碼。

我沒有對此進行基準測試或看看核心,但它明顯更快。

{-# LANGUAGE DataKinds #-} 
{-# LANGUAGE KindSignatures #-} 
{-# LANGUAGE PolyKinds #-} 
{-# LANGUAGE ScopedTypeVariables #-} 
module Main where 

import qualified Data.Vector.Unboxed as V 

data Proxy a = Proxy 

{-# INLINE incAll #-} 
incAll :: V.Vector Int -> V.Vector Int 
incAll = V.map (+ 1) 

oldNTimes :: Int -> (a -> a) -> a -> a 
oldNTimes 0 f x = x 
oldNTimes n f x = f (oldNTimes (n-1) f x) 

-- New definition 

data N = Z | S N 

class Unroll (n :: N) where 
    nTimes :: Proxy n -> (a -> a) -> a -> a 

instance Unroll Z where 
    nTimes _ f x = x 

instance Unroll n => Unroll (S n) where 
    nTimes p f x = 
     let Proxy :: Proxy (S n) = p 
     in f (nTimes (Proxy :: Proxy n) f x) 

main :: IO() 
main = do 
    let size = 100000000 :: Int 
    let array = V.replicate size 0 :: V.Vector Int 
    print $ V.sum (nTimes (Proxy :: Proxy (S (S (S (S (S (S (S (S (S (S (S Z)))))))))))) incAll array) 
    print $ V.sum (oldNTimes 11 incAll array) 
+0

不錯,雖然如果你想使用'nTimes 64','Proxy :: Proxy(S(S(S(S ...(SZ)...)'這個詞會比較有趣,我會用它來類型級別的算術,但。有些像'代理(十:*:六:+:四)'。 – Zeta

+0

我仍然無法得到這些類型類的編程惡作劇,任何明確表示我的人都是這樣的人。 – MaiaVictor