Skip to content

Commit 19ce7ff

Browse files
committed
Add AsyncParentKill exception and enhance async handling in AIO,
1 parent 60a6c48 commit 19ce7ff

File tree

3 files changed

+113
-143
lines changed

3 files changed

+113
-143
lines changed

hls-graph/src/Development/IDE/Graph/Internal/Action.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ actionFork act k = do
8181

8282
isAsyncException :: SomeException -> Bool
8383
isAsyncException e
84+
| Just (_ :: SomeAsyncException) <- fromException e = True
8485
| Just (_ :: AsyncCancelled) <- fromException e = True
8586
| Just (_ :: AsyncException) <- fromException e = True
87+
| Just (_ :: AsyncParentKill) <- fromException e = True
8688
| Just (_ :: ExitCode) <- fromException e = True
8789
| otherwise = False
8890

hls-graph/src/Development/IDE/Graph/Internal/Database.hs

Lines changed: 89 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,24 @@
88
{-# LANGUAGE RecordWildCards #-}
99
{-# LANGUAGE TypeFamilies #-}
1010

11-
module Development.IDE.Graph.Internal.Database (compute, newDatabase, incDatabase, build, getDirtySet, getKeysAndVisitAge) where
11+
module Development.IDE.Graph.Internal.Database (compute, newDatabase, incDatabase, build, getDirtySet, getKeysAndVisitAge, AsyncParentKill(..)) where
1212

1313
import Prelude hiding (unzip)
1414

1515
import Control.Concurrent.Async
1616
import Control.Concurrent.Extra
17-
import Control.Concurrent.STM.Stats (STM, atomically,
17+
import Control.Concurrent.STM.Stats (STM, TVar, atomically,
1818
atomicallyNamed,
1919
modifyTVar', newTVarIO,
20-
readTVarIO)
20+
readTVar, readTVarIO,
21+
retry)
2122
import Control.Exception
2223
import Control.Monad
2324
import Control.Monad.IO.Class (MonadIO (liftIO))
2425
import Control.Monad.Trans.Class (lift)
2526
import Control.Monad.Trans.Reader
2627
import qualified Control.Monad.Trans.State.Strict as State
2728
import Data.Dynamic
28-
import Data.Either
2929
import Data.Foldable (for_, traverse_)
3030
import Data.IORef.Extra
3131
import Data.Maybe
@@ -39,8 +39,9 @@ import Development.IDE.Graph.Internal.Types
3939
import qualified Focus
4040
import qualified ListT
4141
import qualified StmContainers.Map as SMap
42-
import System.IO.Unsafe
4342
import System.Time.Extra (duration, sleep)
43+
import UnliftIO (MonadUnliftIO (withRunInIO))
44+
import qualified UnliftIO.Exception as UE
4445

4546
#if MIN_VERSION_base(4,19,0)
4647
import Data.Functor (unzip)
@@ -78,7 +79,7 @@ incDatabase db Nothing = do
7879
updateDirty :: Monad m => Focus.Focus KeyDetails m ()
7980
updateDirty = Focus.adjust $ \(KeyDetails status rdeps) ->
8081
let status'
81-
| Running _ _ _ x <- status = Dirty x
82+
| Running _ x <- status = Dirty x
8283
| Clean x <- status = Dirty (Just x)
8384
| otherwise = status
8485
in KeyDetails status' rdeps
@@ -88,11 +89,8 @@ build
8889
=> Database -> Stack -> f key -> IO (f Key, f value)
8990
-- build _ st k | traceShow ("build", st, k) False = undefined
9091
build db stack keys = do
91-
built <- runAIO $ do
92-
built <- builder db stack (fmap newKey keys)
93-
case built of
94-
Left clean -> return clean
95-
Right dirty -> liftIO dirty
92+
step <- readTVarIO $ databaseStep db
93+
!built <- runAIO step $ builder db stack (fmap newKey keys)
9694
let (ids, vs) = unzip built
9795
pure (ids, fmap (asV . resultValue) vs)
9896
where
@@ -102,44 +100,41 @@ build db stack keys = do
102100
-- | Build a list of keys and return their results.
103101
-- If none of the keys are dirty, we can return the results immediately.
104102
-- Otherwise, a blocking computation is returned *which must be evaluated asynchronously* to avoid deadlock.
105-
builder
106-
:: Traversable f => Database -> Stack -> f Key -> AIO (Either (f (Key, Result)) (IO (f (Key, Result))))
103+
builder :: (Traversable f) => Database -> Stack -> f Key -> AIO (f (Key, Result))
107104
-- builder _ st kk | traceShow ("builder", st,kk) False = undefined
108-
builder db@Database{..} stack keys = withRunInIO $ \(RunInIO run) -> do
109-
-- Things that I need to force before my results are ready
110-
toForce <- liftIO $ newTVarIO []
111-
current <- liftIO $ readTVarIO databaseStep
112-
results <- liftIO $ for keys $ \id ->
113-
-- Updating the status of all the dependencies atomically is not necessary.
114-
-- Therefore, run one transaction per dep. to avoid contention
115-
atomicallyNamed "builder" $ do
116-
-- Spawn the id if needed
117-
status <- SMap.lookup id databaseValues
118-
val <- case viewDirty current $ maybe (Dirty Nothing) keyStatus status of
119-
Clean r -> pure r
120-
Running _ force val _
121-
| memberStack id stack -> throw $ StackException stack
122-
| otherwise -> do
123-
modifyTVar' toForce (Wait force :)
124-
pure val
125-
Dirty s -> do
126-
let act = run (refresh db stack id s)
127-
(force, val) = splitIO (join act)
128-
SMap.focus (updateStatus $ Running current force val s) id databaseValues
129-
modifyTVar' toForce (Spawn force:)
130-
pure val
131-
132-
pure (id, val)
133-
134-
toForceList <- liftIO $ readTVarIO toForce
135-
let waitAll = run $ waitConcurrently_ toForceList
136-
case toForceList of
137-
[] -> return $ Left results
138-
_ -> return $ Right $ do
139-
waitAll
140-
pure results
141-
142-
105+
builder db stack keys = do
106+
keyWaits <- for keys $ \k -> builderOne db stack k
107+
!res <- for keyWaits $ \(k, waitR) -> do
108+
!v<- liftIO waitR
109+
return (k, v)
110+
return res
111+
112+
builderOne :: Database -> Stack -> Key -> AIO (Key, IO Result)
113+
builderOne db@Database {..} stack id = UE.mask_ $ do
114+
current <- liftIO $ readTVarIO databaseStep
115+
(k, registerWaitResult) <- liftIO $ atomicallyNamed "builder" $ do
116+
-- Spawn the id if needed
117+
status <- SMap.lookup id databaseValues
118+
val <-
119+
let refreshRsult s = do
120+
let act =
121+
asyncWithCleanUp $
122+
refresh db stack id s
123+
`UE.onException` liftIO (atomicallyNamed "builder - onException" (SMap.focus updateDirty id databaseValues))
124+
125+
SMap.focus (updateStatus $ Running current s) id databaseValues
126+
return act
127+
in case viewDirty current $ maybe (Dirty Nothing) keyStatus status of
128+
Dirty mbr -> refreshRsult mbr
129+
Running step _mbr
130+
| step /= current -> error $ "Inconsistent database state: key " ++ show id ++ " is marked Running at step " ++ show step ++ " but current step is " ++ show current
131+
| memberStack id stack -> throw $ StackException stack
132+
| otherwise -> retry
133+
Clean r -> pure . pure . pure $ r
134+
-- force here might contains async exceptions from previous runs
135+
pure (id, val)
136+
waitR <- registerWaitResult
137+
return (k, waitR)
143138
-- | isDirty
144139
-- only dirty when it's build time is older than the changed time of one of its dependencies
145140
isDirty :: Foldable t => Result -> t (a, Result) -> Bool
@@ -155,41 +150,37 @@ isDirty me = any (\(_,dep) -> resultBuilt me < resultChanged dep)
155150
refreshDeps :: KeySet -> Database -> Stack -> Key -> Result -> [KeySet] -> AIO Result
156151
refreshDeps visited db stack key result = \case
157152
-- no more deps to refresh
158-
[] -> liftIO $ compute db stack key RunDependenciesSame (Just result)
153+
[] -> compute' db stack key RunDependenciesSame (Just result)
159154
(dep:deps) -> do
160155
let newVisited = dep <> visited
161156
res <- builder db stack (toListKeySet (dep `differenceKeySet` visited))
162-
case res of
163-
Left res -> if isDirty result res
157+
if isDirty result res
164158
-- restart the computation if any of the deps are dirty
165-
then liftIO $ compute db stack key RunDependenciesChanged (Just result)
159+
then compute' db stack key RunDependenciesChanged (Just result)
166160
-- else kick the rest of the deps
167161
else refreshDeps newVisited db stack key result deps
168-
Right iores -> do
169-
res <- liftIO iores
170-
if isDirty result res
171-
then liftIO $ compute db stack key RunDependenciesChanged (Just result)
172-
else refreshDeps newVisited db stack key result deps
173-
174-
-- | Refresh a key:
175-
refresh :: Database -> Stack -> Key -> Maybe Result -> AIO (IO Result)
162+
163+
164+
-- refresh :: Database -> Stack -> Key -> Maybe Result -> IO Result
176165
-- refresh _ st k _ | traceShow ("refresh", st, k) False = undefined
166+
refresh :: Database -> Stack -> Key -> Maybe Result -> AIO Result
177167
refresh db stack key result = case (addStack key stack, result) of
178168
(Left e, _) -> throw e
179-
(Right stack, Just me@Result{resultDeps = ResultDeps deps}) -> asyncWithCleanUp $ refreshDeps mempty db stack key me (reverse deps)
180-
(Right stack, _) ->
181-
asyncWithCleanUp $ liftIO $ compute db stack key RunDependenciesChanged result
169+
(Right stack, Just me@Result{resultDeps = ResultDeps deps}) -> refreshDeps mempty db stack key me (reverse deps)
170+
(Right stack, _) -> compute' db stack key RunDependenciesChanged result
182171

172+
compute' :: Database -> Stack -> Key -> RunMode -> Maybe Result -> AIO Result
173+
compute' db stack key mode result = liftIO $ compute db stack key mode result
183174
-- | Compute a key.
184175
compute :: Database -> Stack -> Key -> RunMode -> Maybe Result -> IO Result
185176
-- compute _ st k _ _ | traceShow ("compute", st, k) False = undefined
186177
compute db@Database{..} stack key mode result = do
187178
let act = runRule databaseRules key (fmap resultData result) mode
188-
deps <- newIORef UnknownDeps
179+
deps <- liftIO $ newIORef UnknownDeps
189180
(execution, RunResult{..}) <-
190-
duration $ runReaderT (fromAction act) $ SAction db deps stack
191-
curStep <- readTVarIO databaseStep
192-
deps <- readIORef deps
181+
liftIO $ duration $ runReaderT (fromAction act) $ SAction db deps stack
182+
curStep <- liftIO $ readTVarIO databaseStep
183+
deps <- liftIO $ readIORef deps
193184
let lastChanged = maybe curStep resultChanged result
194185
let lastBuild = maybe curStep resultBuilt result
195186
-- changed time is always older than or equal to build time
@@ -212,12 +203,12 @@ compute db@Database{..} stack key mode result = do
212203
-- If an async exception strikes before the deps have been recorded,
213204
-- we won't be able to accurately propagate dirtiness for this key
214205
-- on the next build.
215-
void $
206+
liftIO $ void $
216207
updateReverseDeps key db
217208
(getResultDepsDefault mempty previousDeps)
218209
deps
219210
_ -> pure ()
220-
atomicallyNamed "compute and run hook" $ do
211+
liftIO $ atomicallyNamed "compute and run hook" $ do
221212
runHook
222213
SMap.focus (updateStatus $ Clean res) key databaseValues
223214
pure res
@@ -247,18 +238,6 @@ getKeysAndVisitAge db = do
247238
getAge Result{resultVisited = Step s} = curr - s
248239
return keysWithVisitAge
249240
--------------------------------------------------------------------------------
250-
-- Lazy IO trick
251-
252-
data Box a = Box {fromBox :: a}
253-
254-
-- | Split an IO computation into an unsafe lazy value and a forcing computation
255-
splitIO :: IO a -> (IO (), a)
256-
splitIO act = do
257-
let act2 = Box <$> act
258-
let res = unsafePerformIO act2
259-
(void $ evaluate res, fromBox res)
260-
261-
--------------------------------------------------------------------------------
262241
-- Reverse dependencies
263242

264243
-- | Update the reverse dependencies of an Id
@@ -301,14 +280,29 @@ transitiveDirtySet database = flip State.execStateT mempty . traverse_ loop
301280

302281
-- | A simple monad to implement cancellation on top of 'Async',
303282
-- generalizing 'withAsync' to monadic scopes.
304-
newtype AIO a = AIO { unAIO :: ReaderT (IORef [Async ()]) IO a }
283+
newtype AIO a = AIO { unAIO :: ReaderT (TVar [Async ()]) IO a }
305284
deriving newtype (Applicative, Functor, Monad, MonadIO)
306285

286+
data AsyncParentKill = AsyncParentKill ThreadId Step
287+
deriving (Show, Eq)
288+
289+
instance Exception AsyncParentKill where
290+
toException = asyncExceptionToException
291+
fromException = asyncExceptionFromException
292+
307293
-- | Run the monadic computation, cancelling all the spawned asyncs if an exception arises
308-
runAIO :: AIO a -> IO a
309-
runAIO (AIO act) = do
310-
asyncs <- newIORef []
311-
runReaderT act asyncs `onException` cleanupAsync asyncs
294+
runAIO :: Step -> AIO a -> IO a
295+
runAIO s (AIO act) = do
296+
asyncsRef <- newTVarIO []
297+
-- Log the exact exception (including async exceptions) before cleanup,
298+
-- then rethrow to preserve previous semantics.
299+
runReaderT act asyncsRef `onException` do
300+
asyncs <- atomically $ do
301+
r <- readTVar asyncsRef
302+
modifyTVar' asyncsRef $ const []
303+
return r
304+
tid <- myThreadId
305+
cleanupAsync asyncs tid s
312306

313307
-- | Like 'async' but with built-in cancellation.
314308
-- Returns an IO action to wait on the result.
@@ -319,27 +313,25 @@ asyncWithCleanUp act = do
319313
-- mask to make sure we keep track of the spawned async
320314
liftIO $ uninterruptibleMask $ \restore -> do
321315
a <- async $ restore io
322-
atomicModifyIORef'_ st (void a :)
316+
atomically $ modifyTVar' st (void a :)
323317
return $ wait a
324318

325319
unliftAIO :: AIO a -> AIO (IO a)
326320
unliftAIO act = do
327321
st <- AIO ask
328322
return $ runReaderT (unAIO act) st
329323

330-
newtype RunInIO = RunInIO (forall a. AIO a -> IO a)
324+
instance MonadUnliftIO AIO where
325+
withRunInIO k = do
326+
st <- AIO ask
327+
liftIO $ k (\aio -> runReaderT (unAIO aio) st)
331328

332-
withRunInIO :: (RunInIO -> AIO b) -> AIO b
333-
withRunInIO k = do
334-
st <- AIO ask
335-
k $ RunInIO (\aio -> runReaderT (unAIO aio) st)
336-
337-
cleanupAsync :: IORef [Async a] -> IO ()
329+
cleanupAsync :: [Async a] -> ThreadId -> Step -> IO ()
338330
-- mask to make sure we interrupt all the asyncs
339-
cleanupAsync ref = uninterruptibleMask $ \unmask -> do
340-
asyncs <- atomicModifyIORef' ref ([],)
331+
cleanupAsync asyncs tid step = uninterruptibleMask $ \unmask -> do
341332
-- interrupt all the asyncs without waiting
342-
mapM_ (\a -> throwTo (asyncThreadId a) AsyncCancelled) asyncs
333+
-- mapM_ (\a -> throwTo (asyncThreadId a) AsyncCancelled) asyncs
334+
mapM_ (\a -> throwTo (asyncThreadId a) $ AsyncParentKill tid step) asyncs
343335
-- Wait until all the asyncs are done
344336
-- But if it takes more than 10 seconds, log to stderr
345337
unless (null asyncs) $ do
@@ -348,32 +340,3 @@ cleanupAsync ref = uninterruptibleMask $ \unmask -> do
348340
traceM "cleanupAsync: waiting for asyncs to finish"
349341
withAsync warnIfTakingTooLong $ \_ ->
350342
mapM_ waitCatch asyncs
351-
352-
data Wait
353-
= Wait {justWait :: !(IO ())}
354-
| Spawn {justWait :: !(IO ())}
355-
356-
fmapWait :: (IO () -> IO ()) -> Wait -> Wait
357-
fmapWait f (Wait io) = Wait (f io)
358-
fmapWait f (Spawn io) = Spawn (f io)
359-
360-
waitOrSpawn :: Wait -> IO (Either (IO ()) (Async ()))
361-
waitOrSpawn (Wait io) = pure $ Left io
362-
waitOrSpawn (Spawn io) = Right <$> async io
363-
364-
waitConcurrently_ :: [Wait] -> AIO ()
365-
waitConcurrently_ [] = pure ()
366-
waitConcurrently_ [one] = liftIO $ justWait one
367-
waitConcurrently_ many = do
368-
ref <- AIO ask
369-
-- spawn the async computations.
370-
-- mask to make sure we keep track of all the asyncs.
371-
(asyncs, syncs) <- liftIO $ uninterruptibleMask $ \unmask -> do
372-
waits <- liftIO $ traverse (waitOrSpawn . fmapWait unmask) many
373-
let (syncs, asyncs) = partitionEithers waits
374-
liftIO $ atomicModifyIORef'_ ref (asyncs ++)
375-
return (asyncs, syncs)
376-
-- work on the sync computations
377-
liftIO $ sequence_ syncs
378-
-- wait for the async computations before returning
379-
liftIO $ traverse_ wait asyncs

0 commit comments

Comments
 (0)