Skip to content

Commit

Permalink
Implement core transformation let-hoisting (#2076)
Browse files Browse the repository at this point in the history
- Closes #2033 
- Based on #2032
  • Loading branch information
janmasrovira authored May 16, 2023
1 parent 185937f commit 3ed30dd
Show file tree
Hide file tree
Showing 20 changed files with 347 additions and 12 deletions.
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ data TransformationId
| DisambiguateNames
| CheckGeb
| CheckExec
| Normalize
| LetFolding
| LambdaFolding
| LetHoisting
| Inlining
| FoldTypeSynonyms
| OptPhaseEval
Expand Down Expand Up @@ -64,6 +66,9 @@ toEvalTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt,
toNormalizeTransformations :: [TransformationId]
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]

toVampIRTransformations :: [TransformationId]
toVampIRTransformations = toNormalizeTransformations ++ [Normalize, LetHoisting]

toStrippedTransformations :: [TransformationId]
toStrippedTransformations =
toEvalTransformations ++ [CheckExec, LambdaLetRecLifting, OptPhaseExec, TopEtaExpand, MoveApps, RemoveTypeArgs]
Expand Down
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ transformationText = \case
DisambiguateNames -> strDisambiguateNames
CheckGeb -> strCheckGeb
CheckExec -> strCheckExec
Normalize -> strNormalize
LetFolding -> strLetFolding
LambdaFolding -> strLambdaFolding
LetHoisting -> strLetHoisting
Inlining -> strInlining
FoldTypeSynonyms -> strFoldTypeSynonyms
OptPhaseEval -> strOptPhaseEval
Expand All @@ -97,6 +99,9 @@ transformation = P.choice [symbol (transformationText t) $> t | t <- allElements
allStrings :: [Text]
allStrings = map transformationLikeText allTransformationLikeIds

strLetHoisting :: Text
strLetHoisting = "let-hoisting"

strEvalPipeline :: Text
strEvalPipeline = "pipeline-eval"

Expand Down Expand Up @@ -160,6 +165,9 @@ strCheckGeb = "check-geb"
strCheckExec :: Text
strCheckExec = "check-exec"

strNormalize :: Text
strNormalize = "normalize"

strLetFolding :: Text
strLetFolding = "let-folding"

Expand Down
7 changes: 7 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Data.Functor.Identity
import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Core.Data.BinderList (BinderList)
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Info.NameInfo (setInfoName)
import Juvix.Compiler.Core.Language
import Polysemy.Input

Expand All @@ -17,6 +18,9 @@ import Polysemy.Input
mkVar :: Info -> Index -> Node
mkVar i idx = NVar (Var i idx)

mkVarN :: Text -> Index -> Node
mkVarN name idx = NVar (Var (setInfoName name mempty) idx)

mkVar' :: Index -> Node
mkVar' = mkVar Info.empty

Expand Down Expand Up @@ -71,6 +75,9 @@ mkLet i bi v b = NLet (Let i (LetItem bi v) b)
mkLet' :: Type -> Node -> Node -> Node
mkLet' ty = mkLet Info.empty (mkBinder' ty)

mkLets :: [LetItem] -> Node -> Node
mkLets tvs n = foldl' (\n' itm -> NLet (Let mempty itm n')) n (reverse tvs)

mkLets' :: [(Type, Node)] -> Node -> Node
mkLets' tvs n = foldl' (\n' (ty, v) -> mkLet' ty v n') n (reverse tvs)

Expand Down
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ substs t = umapN go
| idx > k -> mkVar i (idx - len)
_ -> n

-- | Increase the indices of free variables in the binderType by a given value
shiftBinder :: Int -> Binder -> Binder
shiftBinder = over binderType . shift

-- | Increase the indices of free variables in the item binder and value
shiftLetItem :: Int -> LetItem -> LetItem
shiftLetItem n l =
LetItem
{ _letItemBinder = shiftBinder n (l ^. letItemBinder),
_letItemValue = shift n (l ^. letItemValue)
}

-- | substitute a term t for the free variable with de Bruijn index 0, avoiding
-- variable capture; shifts all free variabes with de Bruijn index > 0 by -1 (as
-- if the topmost binder was removed)
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ data ConstantValue
| ConstString !Text
deriving stock (Eq)

-- | Info about a single binder. Associated with Lambda and Pi.
-- | Info about a single binder. Associated with Lambda, Pi, Let, Case or Match.
data Binder' ty = Binder
{ _binderName :: Text,
_binderLocation :: Maybe Location,
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ import Juvix.Compiler.Core.Transformation.FoldTypeSynonyms
import Juvix.Compiler.Core.Transformation.Identity
import Juvix.Compiler.Core.Transformation.IntToPrimInt
import Juvix.Compiler.Core.Transformation.LambdaLetRecLifting
import Juvix.Compiler.Core.Transformation.LetHoisting
import Juvix.Compiler.Core.Transformation.MatchToCase
import Juvix.Compiler.Core.Transformation.MoveApps
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
import Juvix.Compiler.Core.Transformation.NatToPrimInt
import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
Expand Down Expand Up @@ -59,8 +61,10 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
DisambiguateNames -> return . disambiguateNames
CheckGeb -> mapError (JuvixError @CoreError) . checkGeb
CheckExec -> mapError (JuvixError @CoreError) . checkExec
Normalize -> return . normalize
LetFolding -> return . letFolding
LambdaFolding -> return . lambdaFolding
LetHoisting -> return . letHoisting
Inlining -> inlining
FoldTypeSynonyms -> return . foldTypeSynonyms
OptPhaseEval -> Phase.Eval.optimize
Expand Down
121 changes: 121 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/LetHoisting.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
-- Moves al let expressions at the top, just after the top lambdas. This
-- transformation assumes:
-- - There are no LetRecs, Lambdas (other than the ones at the top), nor Match.
-- - Case nodes do not have binders.
-- - All variables reference either a lambda or a let.
-- - All let and lambda binders have type Int.
-- - Let nodes do not appear under Pi binders.
module Juvix.Compiler.Core.Transformation.LetHoisting
( module Juvix.Compiler.Core.Transformation.LetHoisting,
module Juvix.Compiler.Core.Transformation.Base,
)
where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.InfoTableBuilder
import Juvix.Compiler.Core.Extra.Recursors.Map.Named
import Juvix.Compiler.Core.Extra.Utils
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base

data LItem = LItem
{ _itemLet :: LetItem,
_itemLevel :: Level,
_itemName :: Text,
_itemSymbol :: Symbol
}

makeLenses ''LItem

-- | `LItem` indexed by Symbol
type LetsTable = HashMap Symbol (Indexed LItem)

mkLetsTable :: [Indexed LItem] -> LetsTable
mkLetsTable l = HashMap.fromList [(i ^. indexedThing . itemSymbol, i) | i <- l]

letHoisting :: InfoTable -> InfoTable
letHoisting = run . mapT' (const letHoist)

letHoist :: forall r. Members '[InfoTableBuilder] r => Node -> Sem r Node
letHoist n = do
let (topLambdas, body) = unfoldLambdas n
(l, body') <- runReader @[Symbol] [] (runOutputList @LItem (removeLets body))
let il = indexFrom 0 l
tbl = mkLetsTable il
nlets = length il
mkLetItem :: Indexed LItem -> LetItem
mkLetItem i = shiftLetItem (i ^. indexedIx) (i ^. indexedThing . itemLet)
letItems = map mkLetItem il
body'' = substPlaceholders tbl (mkLets letItems (shift nlets body'))
return (reLambdas topLambdas body'')

-- | Removes every Let node and replaces references to it with a unique symbol.
removeLets :: forall r. Members '[InfoTableBuilder, Output LItem, Reader [Symbol]] r => Node -> Sem r Node
removeLets = go mempty
where
go :: BinderList Binder -> Node -> Sem r Node
go bl = dmapLRM' (bl, f)
f ::
BinderList Binder ->
Node ->
Sem r Recur
f bl = \case
NVar v
| v ^. varIndex < length bl -> do
End . mkIdent' . (!! (v ^. varIndex)) <$> ask
| otherwise -> return . End . NVar . shiftVar (-length bl) $ v
NLet l -> do
let _itemLevel = length bl
_itemSymbol <- freshSymbol
-- note that the binder does not need to be hoisted because it is
-- assumed to have type Int
let bi = l ^. letItem . letItemBinder
value' <- go bl (l ^. letItem . letItemValue)
output
LItem
{ _itemLet = LetItem bi value',
_itemName = bi ^. binderName,
_itemSymbol,
_itemLevel
}
r <- local (_itemSymbol :) (go (BL.cons bi bl) (l ^. letBody))
return (End r)
other -> return (Recur other)

-- | Replaces the placeholders with variables that point to the hoisted let.
substPlaceholders :: LetsTable -> Node -> Node
substPlaceholders tbl = dmapN go
where
go :: Level -> Node -> Node
go lvl = \case
NIdt i
| Just (t :: Indexed LItem) <- HashMap.lookup (i ^. identSymbol) tbl ->
mkVarN (t ^. indexedThing . itemName) (lvl - t ^. indexedIx - 1)
m -> m

-- | True if it is of the form λ … λ let a₁ = b₁; … aₙ = bₙ in body;
-- where body does not contain any let.
isLetHoisted :: Node -> Bool
isLetHoisted =
checkBody
. snd
. unfoldLambdas
where
checkBody :: Node -> Bool
checkBody n = isJust . run . runFail $ do
k <- peelLets n
noLets k
peelLets :: Members '[Fail] r => Node -> Sem r Node
peelLets = \case
NLet Let {..} -> do
noLets (_letItem ^. letItemValue)
peelLets _letBody
n -> return n
noLets :: forall r. Members '[Fail] r => Node -> Sem r ()
noLets = walk go
where
go :: Node -> Sem r ()
go = \case
NLet {} -> fail
_ -> return ()
4 changes: 0 additions & 4 deletions src/Juvix/Compiler/Core/Transformation/NaiveMatchToCase.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ compileMatchBranch (Indexed branchNum br) = do
(auxiliaryBindersNum + patternBindersNum' + patternsNum + branchNum)
(br ^. matchBranchBody)

-- | Increase the indices of free variables in the binderTyped by a given value
shiftBinder :: Index -> Binder -> Binder
shiftBinder idx = over binderType (shift idx)

-- | Make a sequence of nested lets from a list of binders / value pairs. The
-- indices of free variables in binder types are shifted by the sum of
-- `baseShift` and the number of lets that have already been added in the
Expand Down
15 changes: 15 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Normalize.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module Juvix.Compiler.Core.Transformation.Normalize where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Core.Normalizer qualified as Normalizer
import Juvix.Compiler.Core.Transformation.Base

normalize :: InfoTable -> InfoTable
normalize tab =
pruneInfoTable $
set identContext (HashMap.singleton sym node) $
set infoIdentifiers (HashMap.singleton sym ii) tab
where
sym = fromJust $ tab ^. infoMain
node = Normalizer.normalize tab (lookupIdentifierNode tab sym)
ii = lookupIdentifierInfo tab sym
4 changes: 3 additions & 1 deletion test/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Core.Normalize qualified as Normalize
import Core.Print qualified as Print
import Core.Recursor qualified as Rec
import Core.Transformation qualified as Transformation
import Core.VampIR qualified as VampIR

allTests :: TestTree
allTests =
Expand All @@ -19,5 +20,6 @@ allTests =
Transformation.allTests,
Asm.allTests,
Compile.allTests,
Normalize.allTests
Normalize.allTests,
VampIR.allTests
]
4 changes: 3 additions & 1 deletion test/Core/Eval/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation
import Juvix.Compiler.Core.Translation.FromSource

data EvalMode = EvalModePlain | EvalModeJSON
data EvalMode
= EvalModePlain
| EvalModeJSON

data EvalData = EvalData
{ _evalDataInput :: [Text],
Expand Down
30 changes: 25 additions & 5 deletions test/Core/Normalize/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,31 @@ fromTest = mkTest . toTestDescr
root :: Path Abs Dir
root = relToProject $(mkRelDir "tests/VampIR/positive/Core")

toTestDescr :: PosTest -> TestDescr
toTestDescr PosTest {..} =
toTestDescr' ::
( Path Abs File ->
Path Abs File ->
(String -> IO ()) ->
Assertion
) ->
PosTest ->
TestDescr
toTestDescr' assertion PosTest {..} =
let tRoot = root <//> _relDir
file' = tRoot <//> _file
expected' = tRoot <//> _expectedFile
in TestDescr
{ _testName = _name,
_testRoot = tRoot,
_testAssertion = Steps $ coreNormalizeAssertion file' expected'
_testAssertion = Steps $ assertion file' expected'
}

toTestDescr :: PosTest -> TestDescr
toTestDescr = toTestDescr' coreNormalizeAssertion

allTests :: TestTree
allTests =
testGroup
"JuvixCore positive tests"
"JuvixCore normalize positive tests"
(map (mkTest . toTestDescr) tests)

tests :: [PosTest]
Expand Down Expand Up @@ -171,5 +181,15 @@ tests =
"Test027: type synonyms"
$(mkRelDir ".")
$(mkRelFile "test027.jvc")
$(mkRelFile "data/test027.json")
$(mkRelFile "data/test027.json"),
PosTest
"Test028: let hoisting"
$(mkRelDir ".")
$(mkRelFile "test028.jvc")
$(mkRelFile "data/test028.json"),
PosTest
"Test029: let hoisting"
$(mkRelDir ".")
$(mkRelFile "test029.jvc")
$(mkRelFile "data/test029.json")
]
8 changes: 8 additions & 0 deletions test/Core/VampIR.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module Core.VampIR where

import Base
import Core.VampIR.LetHoist qualified as LetHoist
import Core.VampIR.Positive qualified as P

allTests :: TestTree
allTests = testGroup "JuvixCore VampIR" [LetHoist.allTests, P.allTests]
Loading

0 comments on commit 3ed30dd

Please sign in to comment.