Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve parallel template #2809

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 91 additions & 86 deletions src/Parallel/ParallelTemplate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ module Parallel.ParallelTemplate
compileArgsNumWorkers,
compileArgsCompileNode,
compileArgsPreProcess,
compilationError,
compile,
)
where

import Control.Concurrent (ThreadId)
import Control.Concurrent.STM.TVar (stateTVar)
import Control.Exception qualified as GHC
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Effectful.Concurrent
Expand All @@ -49,7 +47,6 @@ data CompilationState nodeId compiledProof = CompilationState
_compilationPending :: HashMap nodeId (HashSet nodeId),
_compilationStartedNum :: Natural,
_compilationFinishedNum :: Natural,
_compilationError :: Maybe JuvixError,
_compilationTotalNum :: Natural
}

Expand All @@ -67,14 +64,31 @@ newtype CompileQueue nodeId = CompileQueue
{ _compileQueue :: TBQueue nodeId
}

data LogQueueItem
= LogQueueItem LogItem
| -- | no more log items will be handled after this
LogQueueClose

newtype Logs = Logs
{ _logQueue :: TQueue LogItem
{ _logQueue :: TQueue LogQueueItem
}

newtype NodesIndex nodeId node = NodesIndex
{ _nodesIndex :: HashMap nodeId node
}

data Task nodeId node = Task
{ _taskNum :: Natural,
_taskTotal :: Natural,
_taskNodeId :: nodeId,
_taskNode :: node
}

data Finished
= -- | All modules have started compilation. They might still be compiling
FinishedNoPending
| FinishedPending

makeLenses ''Logs
makeLenses ''NodesIndex
makeLenses ''CompileQueue
Expand All @@ -89,16 +103,11 @@ instance (Show nodeId, Pretty nodeId) => Pretty (Dependencies nodeId) where
| (from, deps) <- HashMap.toList (d ^. dependenciesTable)
]

data Finished
= FinishedOk
| FinishedError JuvixError
| FinishedNot

compilationStateFinished :: CompilationState nodeId compileProof -> Finished
compilationStateFinished CompilationState {..}
| Just err <- _compilationError = FinishedError err
| _compilationFinishedNum == _compilationTotalNum = FinishedOk
| otherwise = FinishedNot
| _compilationStartedNum == _compilationTotalNum = FinishedNoPending
| _compilationStartedNum > _compilationTotalNum = impossible
| otherwise = FinishedPending

addCompiledModule ::
forall nodeId proof.
Expand Down Expand Up @@ -156,7 +165,6 @@ compile args@CompileArgs {..} = do
{ _compilationStartedNum = 0,
_compilationFinishedNum = 0,
_compilationTotalNum = numMods,
_compilationError = Nothing,
_compilationPending = deps ^. dependenciesTable,
_compilationState = mempty
}
Expand All @@ -169,58 +177,75 @@ compile args@CompileArgs {..} = do
. runReader deps
. crashOnError
$ do
let newThread ::
forall r' a.
(Members '[Concurrent] r') =>
Sem r' a ->
Sem r' ()
newThread m = void . forkFinally m $ \case
Left err -> GHC.throw err
Right {} -> return ()
withAsync handleLogs $ \_logHandler -> do
let useAsync = False
if
| useAsync ->
replicateConcurrently_ _compileArgsNumWorkers $
lookForWork @nodeId @node @compileProof
| otherwise ->
replicateM_ _compileArgsNumWorkers
. newThread
$ lookForWork @nodeId @node @compileProof
waitForWorkers @nodeId @compileProof
withAsync handleLogs $ \logHandler -> do
replicateConcurrently_ _compileArgsNumWorkers $
lookForWork @nodeId @node @compileProof
wait logHandler
(^. compilationState) <$> readTVarIO varCompilationState

handleLogs :: (Members '[ProgressLog, Concurrent, Reader Logs] r) => Sem r ()
handleLogs = do
x <- asks (^. logQueue) >>= atomically . readTQueue
progressLog x
handleLogs
case x of
LogQueueClose -> return ()
LogQueueItem l -> do
progressLog l
handleLogs

waitForWorkers ::
forall nodeId compileProof r.
( Members
getTask ::
forall nodeId (node :: GHCType) compileProof (s :: [Effect]) r.
( Hashable nodeId,
Members
'[ Concurrent,
Reader (TVar (CompilationState nodeId compileProof)),
Error JuvixError,
Reader (CompileArgs s nodeId node compileProof),
Reader (NodesIndex nodeId node),
Reader (CompileQueue nodeId),
Reader Logs
]
r
) =>
Sem r ()
waitForWorkers = do
Logs logs <- ask
Sem r (Maybe (Task nodeId node))
getTask = do
stVar <- ask @(TVar (CompilationState nodeId compileProof))
qq <- asks (^. compileQueue)
cstVar <- ask @(TVar (CompilationState nodeId compileProof))
(finished, noMoreLogs) <- atomically $ do
idx <- ask @(NodesIndex nodeId node)
logs <- ask
args <- ask @(CompileArgs s nodeId node compileProof)
tid <- myThreadId
atomically $ do
finished <- compilationStateFinished <$> readTVar cstVar
noMoreLogs <- isEmptyTQueue logs
return (finished, noMoreLogs)
let waitMore = waitForWorkers @nodeId @compileProof
case finished of
FinishedError err
| noMoreLogs -> throw err
| otherwise -> waitMore
FinishedNot -> waitMore
FinishedOk -> unless noMoreLogs waitMore
case finished of
FinishedNoPending -> return Nothing
FinishedPending -> do
nextModuleId :: nodeId <- readTBQueue qq
let n :: node =
run
. runReader idx
$ getNode nextModuleId
compSt <- readTVar stVar
modifyTVar stVar (over compilationStartedNum succ)
let num = succ (compSt ^. compilationStartedNum)
total = compSt ^. compilationTotalNum
name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n))
progress :: Doc CodeAnn =
kwBracketL
<> annotate AnnLiteralInteger (pretty num)
<+> kwOf
<+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " "
kwCompiling = annotate AnnKeyword "Compiling"
isLast = num == total
logMsg tid logs (progress <> kwCompiling <> " " <> name)
when isLast (logClose logs)
return $
Just
Task
{ _taskNum = num,
_taskTotal = total,
_taskNodeId = nextModuleId,
_taskNode = n
}

lookForWork ::
forall nodeId node compileProof (s :: [Effect]) r.
Expand All @@ -241,30 +266,9 @@ lookForWork ::
) =>
Sem r ()
lookForWork = do
qq <- asks (^. compileQueue)
stVar <- ask @(TVar (CompilationState nodeId compileProof))
logs <- ask
args <- ask @(CompileArgs s nodeId node compileProof)
idx <- ask @(NodesIndex nodeId node)
tid <- myThreadId
nextModule <- atomically $ do
nextModule :: nodeId <- readTBQueue qq
let n :: node = run . runReader idx $ getNode nextModule
name = annotate (AnnKind KNameTopModule) (pretty ((args ^. compileArgsNodeName) n))
compSt <- readTVar stVar
modifyTVar stVar (over compilationStartedNum succ)
let num = compSt ^. compilationStartedNum
total = compSt ^. compilationTotalNum
progress :: Doc CodeAnn =
kwBracketL
<> annotate AnnLiteralInteger (pretty (succ num))
<+> kwOf
<+> annotate AnnLiteralInteger (pretty total) <> kwBracketR <> " "
kwCompiling = annotate AnnKeyword "Compiling"
logMsg tid logs (progress <> kwCompiling <> " " <> name)
return nextModule
compileNode @s @nodeId @node @compileProof nextModule
lookForWork @nodeId @node @compileProof @s @r
whenJustM (getTask @nodeId @node @compileProof @s) $ \Task {..} -> do
compileNode @s @nodeId @node @compileProof _taskNodeId
lookForWork @nodeId @node @compileProof @s @r

getNode ::
forall nodeId node r.
Expand Down Expand Up @@ -295,13 +299,8 @@ compileNode ::
compileNode nodId = do
m :: node <- getNode nodId
compileFun <- asks @(CompileArgs s nodeId node compileProof) (^. compileArgsCompileNode)
st :: TVar (CompilationState nodeId compileProof) <- ask
result :: Either (CallStack, JuvixError) compileProof <-
inject $
tryError @JuvixError (compileFun m)
case result of
Left (_, err) -> atomically (modifyTVar st (set compilationError (Just err)))
Right proof -> registerCompiledModule @nodeId @node @s @compileProof nodId proof
proof :: compileProof <- inject (compileFun m)
registerCompiledModule @nodeId @node @s @compileProof nodId proof

registerCompiledModule ::
forall nodeId node s compileProof r.
Expand All @@ -328,11 +327,17 @@ registerCompiledModule m proof = do
toQueue <- stateTVar mutSt (swap . addCompiledModule deps m proof)
forM_ toQueue (writeTBQueue qq)

logClose :: Logs -> STM ()
logClose (Logs q) = do
STM.writeTQueue q LogQueueClose

logMsg :: ThreadId -> Logs -> Doc CodeAnn -> STM ()
logMsg tid (Logs q) msg = do
STM.writeTQueue
q
LogItem
{ _logItemMessage = msg,
_logItemThreadId = tid
}
( LogQueueItem
LogItem
{ _logItemMessage = msg,
_logItemThreadId = tid
}
)
Loading