2013-06-15 22 views
16

goworker tail-recursive loop pattern wydaje się działać bardzo dobrze do pisania czystego kodu. Jaki byłby odpowiednik napisania takiej pętli dla monady ST? Dokładniej, chcę uniknąć nowej alokacji sterty w iteracjach pętli. Zgaduję, że wymaga on albo CPS transformation lub fixST, aby ponownie zapisać kod tak, że wszystkie wartości, które zmieniają się w pętli, są przekazywane w każdej iteracji, dzięki czemu lokalizacje rejestrów (lub stos w przypadku rozlania) są dostępne dla tych wartości w całej iteracji . Mam uproszczony przykład poniżej (nie spróbować uruchomić go - to będzie prawdopodobnie awarii z winy segmentacji!) Z udziałem funkcję o nazwie findSnakes który ma go pracownika wzór ale zmieniające się wartości państwowe nie są przekazywane przez argumenty akumulatorowe:Pisanie wydajnej pętli iteracyjnej dla monady ST

{-# LANGUAGE BangPatterns #-} 
module Test where 

import Data.Vector.Unboxed.Mutable as MU 
import Data.Vector.Unboxed as U hiding (mapM_) 
import Control.Monad.ST as ST 
import Control.Monad.Primitive (PrimState) 
import Control.Monad as CM (when,forM_) 
import Data.Int 

type MVI1 s = MVector (PrimState (ST s)) Int 

-- function to find previous y 
findYP :: MVI1 s -> Int -> Int -> ST s Int 
findYP fp k offset = do 
       y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x 
       y1 <- MU.unsafeRead fp (k+offset+1) 
       if y0 > y1 then return y0 
       else return y1 
{-#INLINE findYP #-} 

findSnakes :: Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s() 
findSnakes a fp !k !ct !op = go 0 k 
    where 
      offset=1+U.length a 
      go x k' 
      | x < ct = do 
       yp <- findYP fp k' offset 
       MU.unsafeWrite fp (k'+offset) (yp + k') 
       go (x+1) (op k' 1) 
      | otherwise = return() 
{-#INLINE findSnakes #-} 

Patrząc na cmm wyjścia w ghc 7.6.1 (z mojej ograniczonej wiedzy cmm - proszę mnie poprawić, jeśli mam to źle), widzę tego rodzaju przepływu połączeń, z pętlą w s1tb_info (który powoduje sprawdzanie alokacji sterty i sterty w każdej iteracji):

findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check) 
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out 
what that register points to) 

-- I am guessing this one below is for go loop 
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1 
then s1td_info else R1 

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info 
(a loop) else s1tb_info (after executing a different block of code) 

Domyślam się, że sprawdzenie kodu w postaci arg >= 1 w kodzie cmm ma na celu sprawdzenie, czy zakończyła się pętla go. Jeśli jest to poprawne, wydaje się, że jeśli pętla go nie zostanie przepisana, aby przepuścić yp przez pętlę, alokacja sterty stanie się w pętli dla nowych wartości (domyślam się, że to powoduje, że ta alokacja sterty jest spowodowana przez yp). Jaki byłby skuteczny sposób napisania pętli go w powyższym przykładzie? Domyślam się, że yp musi zostać przekazany jako argument w pętli go lub w równoważny sposób poprzez transformację fixST lub CPS. Nie potrafię wymyślić dobrego sposobu na przepisanie powyższej pętli go, aby usunąć alokacje sterty i doceniam pomoc.

Odpowiedz

3

Przepisałem twoje funkcje, aby uniknąć wyraźnej rekurencji i usunąłem zbędne operacje obliczające przesunięcia. To kompiluje się do znacznie ładniejszego rdzenia niż twoje oryginalne funkcje.

Rdzeń, przy okazji, jest prawdopodobnie lepszym sposobem analizy skompilowanego kodu dla tego rodzaju profilowania. Użyj ghc -ddump-simpl zobaczyć wygenerowany wyjście rdzenia lub narzędzi, takich jak ghc-core

import Control.Monad.Primitive                    
import Control.Monad.ST                      
import Data.Int                        
import qualified Data.Vector.Unboxed.Mutable as M                
import qualified Data.Vector.Unboxed as U                  

type MVI1 s = M.MVector (PrimState (ST s)) Int                

findYP :: MVI1 s -> Int -> ST s Int                                      
findYP fp offset = do                      
    y0 <- M.unsafeRead fp (offset+0)                  
    y1 <- M.unsafeRead fp (offset+2)                  
    return $ max (y0 + 1) y1                     

findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s()                           
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0          
    where writeAt k = do  
       let offset = U.length a + k                 
       yp <- findYP fp offset                   
       M.unsafeWrite fp (offset + 1) (yp + k) 

      -- or inline findYP manually 
      writeAt k = do 
      let offset = U.length a + k 
      y0 <- M.unsafeRead fp (offset + 0) 
      y1 <- M.unsafeRead fp (offset + 2) 
      M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1) 

Ponadto, aby przekazać U.Vector Int32findSnakes tylko obliczyć jego długość i nigdy nie używać a ponownie. Dlaczego nie przekazać bezpośrednio długości?