From 170e3e80a9d522ab3003f7d55b59c27e6b49bfaa Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Tue, 31 Aug 2021 16:36:40 +0200 Subject: [PATCH] Rework async behavior Remove implicit MonadAsync constraint on `ReaderT ResourceManager IO` to prevent accidental forking in a resource-limited context. Change return type to `Awaitable` to match new MonadResourceManager behavior. Add `runUnlimitedAsync` to run a forking MonadAsync. Add `forkTask` functions for explicit forking in any context. Co-authored-by: Jan Beinke <git@janbeinke.com> --- src/Quasar/Async.hs | 142 ++++++++++++++++++++++++++++----------- src/Quasar/Disposable.hs | 6 ++ src/Quasar/Observable.hs | 41 +++++------ src/Quasar/Timer.hs | 11 ++- test/Quasar/AsyncSpec.hs | 4 +- 5 files changed, 133 insertions(+), 71 deletions(-) diff --git a/src/Quasar/Async.hs b/src/Quasar/Async.hs index 6fdc536..4c013ac 100644 --- a/src/Quasar/Async.hs +++ b/src/Quasar/Async.hs @@ -1,6 +1,7 @@ module Quasar.Async ( -- * Async/await MonadAsync(..), + runUnlimitedAsync, async_, asyncWithUnmask_, @@ -16,6 +17,12 @@ module Quasar.Async ( -- ** Task exceptions CancelTask(..), TaskDisposed(..), + + -- * Unmanaged forking + forkTask, + forkTask_, + forkTaskWithUnmask, + forkTaskWithUnmask_, ) where import Control.Concurrent (ThreadId, forkIOWithUnmask, throwTo) @@ -27,66 +34,121 @@ import Quasar.Disposable import Quasar.Prelude -class (MonadAwait m, MonadResourceManager m, MonadMask m) => MonadAsync m where - async :: m r -> m (Task r) +class (MonadAwait m, MonadResourceManager m) => MonadAsync m where + async :: m r -> m (Awaitable r) async action = asyncWithUnmask ($ action) -- | TODO: Documentation -- -- The action will be run with asynchronous exceptions masked and will be passed an action that can be used unmask. - asyncWithUnmask :: ((forall a. m a -> m a) -> m r) -> m (Task r) + -- + -- TODO change signature to `Awaitable` + asyncWithUnmask :: ((forall a. m a -> m a) -> m r) -> m (Awaitable r) + + +instance MonadAsync m => MonadAsync (ReaderT r m) where + asyncWithUnmask :: MonadAsync m => ((forall b. ReaderT r m b -> ReaderT r m b) -> ReaderT r m a) -> ReaderT r m (Awaitable a) + asyncWithUnmask action = do + x <- ask + lift $ asyncWithUnmask \unmask -> runReaderT (action (liftUnmask unmask)) x + where + -- | Lift an "unmask" action (e.g. from `mask`) into a `ReaderT`. + liftUnmask :: (m a -> m a) -> (ReaderT r m) a -> (ReaderT r m) a + liftUnmask unmask action = do + value <- ask + lift $ unmask $ runReaderT action value + + +async_ :: MonadAsync m => m () -> m () +async_ = void . async + +asyncWithUnmask_ :: MonadAsync m => ((forall a. m a -> m a) -> m ()) -> m () +asyncWithUnmask_ action = void $ asyncWithUnmask action + + -instance MonadAsync (ReaderT ResourceManager IO) where +newtype UnlimitedAsync r = UnlimitedAsync { unUnlimitedAsync :: (ReaderT ResourceManager IO r) } + deriving newtype ( + Functor, + Applicative, + Monad, + MonadIO, + MonadThrow, + MonadCatch, + MonadMask, + MonadFail, + Alternative, + MonadPlus, + MonadAwait, + MonadResourceManager + ) + +instance MonadAsync UnlimitedAsync where asyncWithUnmask action = do resourceManager <- askResourceManager + liftIO $ mask_ $ do + task <- forkTaskWithUnmask (\unmask -> runReaderT (unUnlimitedAsync (action (liftUnmask unmask))) resourceManager) + attachDisposable resourceManager task + pure $ toAwaitable task + where + liftUnmask :: (forall b. IO b -> IO b) -> UnlimitedAsync a -> UnlimitedAsync a + liftUnmask unmask (UnlimitedAsync action) = UnlimitedAsync do + resourceManager <- ask + liftIO $ unmask $ runReaderT action resourceManager - liftIO $ mask_ do - resultVar <- newAsyncVar - threadIdVar <- newEmptyTMVarIO - disposable <- attachDisposeAction resourceManager (disposeTask threadIdVar resultVar) +runUnlimitedAsync :: (MonadResourceManager m) => (forall f. MonadAsync f => f r) -> m r +runUnlimitedAsync action = do + resourceManager <- askResourceManager + liftIO $ runReaderT (unUnlimitedAsync action) resourceManager - onException - do - atomically . putTMVar threadIdVar . Just =<< - forkIOWithUnmask \unmask -> do - result <- try $ catch - do runReaderT (action (liftUnmask unmask)) resourceManager - \CancelTask -> throwIO TaskDisposed - putAsyncVarEither_ resultVar result - -- Thread has completed work, "disarm" the disposable and fire it - void $ atomically $ swapTMVar threadIdVar Nothing - disposeAndAwait disposable +forkTask :: MonadIO m => IO a -> m (Task a) +forkTask action = forkTaskWithUnmask ($ action) - do atomically $ putTMVar threadIdVar Nothing +forkTask_ :: MonadIO m => IO () -> m Disposable +forkTask_ action = toDisposable <$> forkTask action - pure $ Task disposable (toAwaitable resultVar) - where - disposeTask :: TMVar (Maybe ThreadId) -> AsyncVar r -> IO (Awaitable ()) - disposeTask threadIdVar resultVar = mask_ do - -- Blocks until the thread is forked - atomically (swapTMVar threadIdVar Nothing) >>= \case - -- Thread completed or initialization failed - Nothing -> pure () - Just threadId -> throwTo threadId CancelTask +forkTaskWithUnmask :: MonadIO m => ((forall b. IO b -> IO b) -> IO a) -> m (Task a) +forkTaskWithUnmask action = do + liftIO $ mask_ do + resultVar <- newAsyncVar + threadIdVar <- newEmptyTMVarIO - -- Wait for task completion or failure. Tasks must not ignore `CancelTask` or this will hang. - pure $ void (toAwaitable resultVar) `catchAll` const (pure ()) + disposable <- newDisposable $ disposeTask threadIdVar resultVar --- | Lift an "unmask" action (e.g. from `mask`) into a `ReaderT`. -liftUnmask :: (IO a -> IO a) -> (ReaderT r IO) a -> (ReaderT r IO) a -liftUnmask unmask action = do - value <- ask - liftIO $ unmask $ runReaderT action value + onException + do + atomically . putTMVar threadIdVar . Just =<< + forkIOWithUnmask \unmask -> do + result <- try $ catch + do action unmask + \CancelTask -> throwIO TaskDisposed + putAsyncVarEither_ resultVar result -async_ :: MonadAsync m => m () -> m () -async_ = void . async + -- Thread has completed work, "disarm" the disposable and fire it + void $ atomically $ swapTMVar threadIdVar Nothing + disposeAndAwait disposable -asyncWithUnmask_ :: MonadAsync m => ((forall a. m a -> m a) -> m ()) -> m () -asyncWithUnmask_ action = void $ asyncWithUnmask action + do atomically $ putTMVar threadIdVar Nothing + + pure $ Task disposable (toAwaitable resultVar) + where + disposeTask :: TMVar (Maybe ThreadId) -> AsyncVar r -> IO (Awaitable ()) + disposeTask threadIdVar resultVar = mask_ do + -- Blocks until the thread is forked + atomically (swapTMVar threadIdVar Nothing) >>= \case + -- Thread completed or initialization failed + Nothing -> pure () + Just threadId -> throwTo threadId CancelTask + + -- Wait for task completion or failure. Tasks must not ignore `CancelTask` or this will hang. + pure $ void (toAwaitable resultVar) `catchAll` const (pure ()) + +forkTaskWithUnmask_ :: MonadIO m => ((forall b. IO b -> IO b) -> IO ()) -> m Disposable +forkTaskWithUnmask_ action = toDisposable <$> forkTaskWithUnmask action diff --git a/src/Quasar/Disposable.hs b/src/Quasar/Disposable.hs index b6df3c7..1494c18 100644 --- a/src/Quasar/Disposable.hs +++ b/src/Quasar/Disposable.hs @@ -217,6 +217,12 @@ instance (MonadMask m, MonadIO m) => MonadResourceManager (ReaderT ResourceManag askResourceManager = ask localResourceManager = local . const +instance {-# OVERLAPPABLE #-} MonadResourceManager m => MonadResourceManager (ReaderT r m) where + askResourceManager = lift askResourceManager + localResourceManager resourceManager action = do + x <- ask + lift $ localResourceManager resourceManager $ runReaderT action x + onResourceManager :: (HasResourceManager a) => a -> ReaderT ResourceManager m r -> m r onResourceManager target action = runReaderT action (getResourceManager target) diff --git a/src/Quasar/Observable.hs b/src/Quasar/Observable.hs index a0da732..5b302a3 100644 --- a/src/Quasar/Observable.hs +++ b/src/Quasar/Observable.hs @@ -10,7 +10,6 @@ module Quasar.Observable ( ObservableMessage(..), toObservableUpdate, asyncObserve, - asyncObserve_, -- * ObservableVar ObservableVar, @@ -76,7 +75,7 @@ toObservableUpdate (ObservableNotAvailable ex) = throwM ex class IsRetrievable v a | a -> v where - retrieve :: MonadAsync m => a -> m (Task v) + retrieve :: (MonadResourceManager m, MonadAwait m) => a -> m (Awaitable v) retrieveIO :: IsRetrievable v a => a -> IO v retrieveIO x = withOnResourceManager $ await =<< retrieve x @@ -94,10 +93,8 @@ class IsRetrievable v o => IsObservable v o | o -> v where resourceManager <- askResourceManager bracketOnError do - -- HACK: use async to fork on MonadResourceManager - -- This should use MonadAsync instead, but this implementation is a temporary compatability wrapper and the - -- constraints are based on the new design. - liftIO $ onResourceManager resourceManager $ async do + -- This implementation is a temporary compatability wrapper and forking isn't necessary with the new design. + forkTask do attachDisposable resourceManager =<< liftIO do unsafeAsyncObserveIO observable \msg -> do currentMsgId <- atomically do @@ -119,11 +116,7 @@ class IsRetrievable v o => IsObservable v o | o -> v where unsafeAsyncObserveIO :: o -> (ObservableMessage v -> IO ()) -> IO Disposable unsafeAsyncObserveIO observable callback = do - resourceManager <- unsafeNewResourceManager - onResourceManager resourceManager do - asyncObserve_ observable (liftIO . callback) - - pure (toDisposable resourceManager) + forkTask_ $ withOnResourceManager $ observe observable (liftIO . callback) toObservable :: o -> Observable v toObservable = Observable @@ -134,11 +127,8 @@ class IsRetrievable v o => IsObservable v o | o -> v where {-# MINIMAL observe | unsafeAsyncObserveIO #-} -asyncObserve :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m Disposable -asyncObserve observable callback = toDisposable <$> async (observe observable callback) - -asyncObserve_ :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m () -asyncObserve_ observable callback = async_ (observe observable callback) +asyncObserve :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m () +asyncObserve observable callback = async_ (observe observable callback) data ObserveWhileCompleted = ObserveWhileCompleted @@ -225,9 +215,9 @@ instance IsObservable v (MappedObservable v) where data BindObservable r = forall a. BindObservable (Observable a) (a -> Observable r) instance IsRetrievable r (BindObservable r) where - retrieve (BindObservable fx fn) = async $ do - x <- awaitResult $ retrieve fx - awaitResult $ retrieve $ fn x + retrieve (BindObservable fx fn) = do + x <- await =<< retrieve fx + retrieve $ fn x instance IsObservable r (BindObservable r) where unsafeAsyncObserveIO :: BindObservable r -> (ObservableMessage r -> IO ()) -> IO Disposable @@ -294,8 +284,7 @@ instance IsObservable r (BindObservable r) where data CatchObservable e r = Exception e => CatchObservable (Observable r) (e -> Observable r) instance IsRetrievable r (CatchObservable e r) where - retrieve (CatchObservable fx fn) = async $ - awaitResult (retrieve fx) `catch` \ex -> awaitResult (retrieve (fn ex)) + retrieve (CatchObservable fx fn) = retrieve fx `catch` \ex -> retrieve (fn ex) instance IsObservable r (CatchObservable e r) where unsafeAsyncObserveIO :: CatchObservable e r -> (ObservableMessage r -> IO ()) -> IO Disposable @@ -361,7 +350,7 @@ instance IsObservable r (CatchObservable e r) where newtype ObservableVar v = ObservableVar (MVar (v, HM.HashMap Unique (ObservableCallback v))) instance IsRetrievable v (ObservableVar v) where - retrieve (ObservableVar mvar) = liftIO $ successfulTask . fst <$> readMVar mvar + retrieve (ObservableVar mvar) = liftIO $ pure . fst <$> readMVar mvar instance IsObservable v (ObservableVar v) where unsafeAsyncObserveIO (ObservableVar mvar) callback = do key <- newUnique @@ -448,7 +437,7 @@ mergeObservable :: (IsObservable v0 o0, IsObservable v1 o1) => (v0 -> v1 -> r) - mergeObservable merge x y = Observable $ MergedObservable merge x y data FnObservable v = FnObservable { - retrieveFn :: forall m. MonadAsync m => m (Task v), + retrieveFn :: forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v), observeFn :: (ObservableMessage v -> IO ()) -> IO Disposable } instance IsRetrievable v (FnObservable v) where @@ -463,7 +452,7 @@ instance IsObservable v (FnObservable v) where -- | Implement an Observable by directly providing functions for `retrieve` and `subscribe`. fnObservable :: ((ObservableMessage v -> IO ()) -> IO Disposable) - -> (forall m. MonadAsync m => m (Task v)) + -> (forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v)) -> Observable v fnObservable observeFn retrieveFn = toObservable FnObservable{observeFn, retrieveFn} @@ -474,8 +463,8 @@ synchronousFnObservable -> Observable v synchronousFnObservable observeFn synchronousRetrieveFn = fnObservable observeFn retrieveFn where - retrieveFn :: (forall m. MonadAsync m => m (Task v)) - retrieveFn = liftIO $ successfulTask <$> synchronousRetrieveFn + retrieveFn :: (forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v)) + retrieveFn = liftIO $ pure <$> synchronousRetrieveFn newtype ConstObservable v = ConstObservable v diff --git a/src/Quasar/Timer.hs b/src/Quasar/Timer.hs index 76eeacd..76511d5 100644 --- a/src/Quasar/Timer.hs +++ b/src/Quasar/Timer.hs @@ -125,7 +125,7 @@ startSchedulerThread scheduler = do wait :: Timer -> Int -> IO () wait nextTimer microseconds = do - delay <- toAwaitable <$> newDelay resourceManager' microseconds + delay <- onResourceManager resourceManager' $ toAwaitable <$> newDelay microseconds awaitAny2 delay nextTimerChanged where nextTimerChanged :: Awaitable () @@ -189,8 +189,13 @@ newtype Delay = Delay (Task ()) instance IsAwaitable () Delay where toAwaitable (Delay task) = toAwaitable task `catch` \TaskDisposed -> throwM TimerCancelled -newDelay :: ResourceManager -> Int -> IO Delay -newDelay resourceManager microseconds = onResourceManager resourceManager $ Delay <$> async (liftIO (threadDelay microseconds)) +newDelay :: MonadResourceManager m => Int -> m Delay +newDelay microseconds = do + resourceManager <- askResourceManager + mask_ do + delay <- Delay <$> forkTask (liftIO (threadDelay microseconds)) + attachDisposable resourceManager delay + pure delay diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs index 8034d9f..15f0d2a 100644 --- a/test/Quasar/AsyncSpec.hs +++ b/test/Quasar/AsyncSpec.hs @@ -14,10 +14,10 @@ spec :: Spec spec = parallel $ do describe "async" $ do it "can pass a value through async and await" $ do - withOnResourceManager (await =<< async (pure 42)) `shouldReturn` (42 :: Int) + withOnResourceManager (runUnlimitedAsync (await =<< async (pure 42))) `shouldReturn` (42 :: Int) it "can pass a value through async and await" $ do - withOnResourceManager (await =<< async (liftIO (threadDelay 100000) >> pure 42)) `shouldReturn` (42 :: Int) + withOnResourceManager (runUnlimitedAsync (await =<< async (liftIO (threadDelay 100000) >> pure 42))) `shouldReturn` (42 :: Int) describe "await" $ do it "can await the result of an async that is completed later" $ do -- GitLab