From 170e3e80a9d522ab3003f7d55b59c27e6b49bfaa Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 31 Aug 2021 16:36:40 +0200
Subject: [PATCH] Rework async behavior

Remove implicit MonadAsync constraint on `ReaderT ResourceManager IO`
to prevent accidental forking in a resource-limited context.

Change return type to `Awaitable` to match new MonadResourceManager
behavior.

Add `runUnlimitedAsync` to run a forking MonadAsync.

Add `forkTask` functions for explicit forking in any context.

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 src/Quasar/Async.hs      | 142 ++++++++++++++++++++++++++++-----------
 src/Quasar/Disposable.hs |   6 ++
 src/Quasar/Observable.hs |  41 +++++------
 src/Quasar/Timer.hs      |  11 ++-
 test/Quasar/AsyncSpec.hs |   4 +-
 5 files changed, 133 insertions(+), 71 deletions(-)

diff --git a/src/Quasar/Async.hs b/src/Quasar/Async.hs
index 6fdc536..4c013ac 100644
--- a/src/Quasar/Async.hs
+++ b/src/Quasar/Async.hs
@@ -1,6 +1,7 @@
 module Quasar.Async (
   -- * Async/await
   MonadAsync(..),
+  runUnlimitedAsync,
   async_,
   asyncWithUnmask_,
 
@@ -16,6 +17,12 @@ module Quasar.Async (
   -- ** Task exceptions
   CancelTask(..),
   TaskDisposed(..),
+
+  -- * Unmanaged forking
+  forkTask,
+  forkTask_,
+  forkTaskWithUnmask,
+  forkTaskWithUnmask_,
 ) where
 
 import Control.Concurrent (ThreadId, forkIOWithUnmask, throwTo)
@@ -27,66 +34,121 @@ import Quasar.Disposable
 import Quasar.Prelude
 
 
-class (MonadAwait m, MonadResourceManager m, MonadMask m) => MonadAsync m where
-  async :: m r -> m (Task r)
+class (MonadAwait m, MonadResourceManager m) => MonadAsync m where
+  async :: m r -> m (Awaitable r)
   async action = asyncWithUnmask ($ action)
 
   -- | TODO: Documentation
   --
   -- The action will be run with asynchronous exceptions masked and will be passed an action that can be used unmask.
-  asyncWithUnmask :: ((forall a. m a -> m a) -> m r) -> m (Task r)
+  --
+  -- TODO change signature to `Awaitable`
+  asyncWithUnmask :: ((forall a. m a -> m a) -> m r) -> m (Awaitable r)
+
+
+instance MonadAsync m => MonadAsync (ReaderT r m) where
+  asyncWithUnmask :: MonadAsync m => ((forall b. ReaderT r m b -> ReaderT r m b) -> ReaderT r m a) -> ReaderT r m (Awaitable a)
+  asyncWithUnmask action = do
+    x <- ask
+    lift $ asyncWithUnmask \unmask -> runReaderT (action (liftUnmask unmask)) x
+    where
+      -- | Lift an "unmask" action (e.g. from `mask`) into a `ReaderT`.
+      liftUnmask :: (m a -> m a) -> (ReaderT r m) a -> (ReaderT r m) a
+      liftUnmask unmask action = do
+        value <- ask
+        lift $ unmask $ runReaderT action value
+
+
+async_ :: MonadAsync m => m () -> m ()
+async_ = void . async
+
+asyncWithUnmask_ :: MonadAsync m => ((forall a. m a -> m a) -> m ()) -> m ()
+asyncWithUnmask_ action = void $ asyncWithUnmask action
+
+
 
-instance MonadAsync (ReaderT ResourceManager IO) where
+newtype UnlimitedAsync r = UnlimitedAsync { unUnlimitedAsync :: (ReaderT ResourceManager IO r) }
+  deriving newtype (
+    Functor,
+    Applicative,
+    Monad,
+    MonadIO,
+    MonadThrow,
+    MonadCatch,
+    MonadMask,
+    MonadFail,
+    Alternative,
+    MonadPlus,
+    MonadAwait,
+    MonadResourceManager
+  )
+
+instance MonadAsync UnlimitedAsync where
   asyncWithUnmask action = do
     resourceManager <- askResourceManager
+    liftIO $ mask_ $ do
+      task <- forkTaskWithUnmask (\unmask -> runReaderT (unUnlimitedAsync (action (liftUnmask unmask))) resourceManager)
+      attachDisposable resourceManager task
+      pure $ toAwaitable task
+    where
+      liftUnmask :: (forall b. IO b -> IO b) -> UnlimitedAsync a -> UnlimitedAsync a
+      liftUnmask unmask (UnlimitedAsync action) = UnlimitedAsync do
+        resourceManager <- ask
+        liftIO $ unmask $ runReaderT action resourceManager
 
-    liftIO $ mask_ do
-      resultVar <- newAsyncVar
-      threadIdVar <- newEmptyTMVarIO
 
-      disposable <- attachDisposeAction resourceManager (disposeTask threadIdVar resultVar)
+runUnlimitedAsync :: (MonadResourceManager m) => (forall f. MonadAsync f => f r) -> m r
+runUnlimitedAsync action = do
+  resourceManager <- askResourceManager
+  liftIO $ runReaderT (unUnlimitedAsync action) resourceManager
 
-      onException
-        do
-          atomically . putTMVar threadIdVar . Just =<<
-            forkIOWithUnmask \unmask -> do
-              result <- try $ catch
-                do runReaderT (action (liftUnmask unmask)) resourceManager
-                \CancelTask -> throwIO TaskDisposed
 
-              putAsyncVarEither_ resultVar result
 
-              -- Thread has completed work, "disarm" the disposable and fire it
-              void $ atomically $ swapTMVar threadIdVar Nothing
-              disposeAndAwait disposable
+forkTask :: MonadIO m => IO a -> m (Task a)
+forkTask action = forkTaskWithUnmask ($ action)
 
-        do atomically $ putTMVar threadIdVar Nothing
+forkTask_ :: MonadIO m => IO () -> m Disposable
+forkTask_ action = toDisposable <$> forkTask action
 
-      pure $ Task disposable (toAwaitable resultVar)
-    where
-      disposeTask :: TMVar (Maybe ThreadId) -> AsyncVar r -> IO (Awaitable ())
-      disposeTask threadIdVar resultVar = mask_ do
-        -- Blocks until the thread is forked
-        atomically (swapTMVar threadIdVar Nothing) >>= \case
-          -- Thread completed or initialization failed
-          Nothing -> pure ()
-          Just threadId -> throwTo threadId CancelTask
+forkTaskWithUnmask :: MonadIO m => ((forall b. IO b -> IO b) -> IO a) -> m (Task a)
+forkTaskWithUnmask action = do
+  liftIO $ mask_ do
+    resultVar <- newAsyncVar
+    threadIdVar <- newEmptyTMVarIO
 
-        -- Wait for task completion or failure. Tasks must not ignore `CancelTask` or this will hang.
-        pure $ void (toAwaitable resultVar) `catchAll` const (pure ())
+    disposable <- newDisposable $ disposeTask threadIdVar resultVar
 
--- | Lift an "unmask" action (e.g. from `mask`) into a `ReaderT`.
-liftUnmask :: (IO a -> IO a) -> (ReaderT r IO) a -> (ReaderT r IO) a
-liftUnmask unmask action = do
-  value <- ask
-  liftIO $ unmask $ runReaderT action value
+    onException
+      do
+        atomically . putTMVar threadIdVar . Just =<<
+          forkIOWithUnmask \unmask -> do
+            result <- try $ catch
+              do action unmask
+              \CancelTask -> throwIO TaskDisposed
 
+            putAsyncVarEither_ resultVar result
 
-async_ :: MonadAsync m => m () -> m ()
-async_ = void . async
+            -- Thread has completed work, "disarm" the disposable and fire it
+            void $ atomically $ swapTMVar threadIdVar Nothing
+            disposeAndAwait disposable
 
-asyncWithUnmask_ :: MonadAsync m => ((forall a. m a -> m a) -> m ()) -> m ()
-asyncWithUnmask_ action = void $ asyncWithUnmask action
+      do atomically $ putTMVar threadIdVar Nothing
+
+    pure $ Task disposable (toAwaitable resultVar)
+  where
+    disposeTask :: TMVar (Maybe ThreadId) -> AsyncVar r -> IO (Awaitable ())
+    disposeTask threadIdVar resultVar = mask_ do
+      -- Blocks until the thread is forked
+      atomically (swapTMVar threadIdVar Nothing) >>= \case
+        -- Thread completed or initialization failed
+        Nothing -> pure ()
+        Just threadId -> throwTo threadId CancelTask
+
+      -- Wait for task completion or failure. Tasks must not ignore `CancelTask` or this will hang.
+      pure $ void (toAwaitable resultVar) `catchAll` const (pure ())
+
+forkTaskWithUnmask_ :: MonadIO m => ((forall b. IO b -> IO b) -> IO ()) -> m Disposable
+forkTaskWithUnmask_ action = toDisposable <$> forkTaskWithUnmask action
 
 
 
diff --git a/src/Quasar/Disposable.hs b/src/Quasar/Disposable.hs
index b6df3c7..1494c18 100644
--- a/src/Quasar/Disposable.hs
+++ b/src/Quasar/Disposable.hs
@@ -217,6 +217,12 @@ instance (MonadMask m, MonadIO m) => MonadResourceManager (ReaderT ResourceManag
   askResourceManager = ask
   localResourceManager = local . const
 
+instance {-# OVERLAPPABLE #-} MonadResourceManager m => MonadResourceManager (ReaderT r m) where
+  askResourceManager = lift askResourceManager
+  localResourceManager resourceManager action = do
+    x <- ask
+    lift $ localResourceManager resourceManager $ runReaderT action x
+
 
 onResourceManager :: (HasResourceManager a) => a -> ReaderT ResourceManager m r -> m r
 onResourceManager target action = runReaderT action (getResourceManager target)
diff --git a/src/Quasar/Observable.hs b/src/Quasar/Observable.hs
index a0da732..5b302a3 100644
--- a/src/Quasar/Observable.hs
+++ b/src/Quasar/Observable.hs
@@ -10,7 +10,6 @@ module Quasar.Observable (
   ObservableMessage(..),
   toObservableUpdate,
   asyncObserve,
-  asyncObserve_,
 
   -- * ObservableVar
   ObservableVar,
@@ -76,7 +75,7 @@ toObservableUpdate (ObservableNotAvailable ex) = throwM ex
 
 
 class IsRetrievable v a | a -> v where
-  retrieve :: MonadAsync m => a -> m (Task v)
+  retrieve :: (MonadResourceManager m, MonadAwait m) => a -> m (Awaitable v)
 
 retrieveIO :: IsRetrievable v a => a -> IO v
 retrieveIO x = withOnResourceManager $ await =<< retrieve x
@@ -94,10 +93,8 @@ class IsRetrievable v o => IsObservable v o | o -> v where
     resourceManager <- askResourceManager
     bracketOnError
       do
-        -- HACK: use async to fork on MonadResourceManager
-        -- This should use MonadAsync instead, but this implementation is a temporary compatability wrapper and the
-        -- constraints are based on the new design.
-        liftIO $ onResourceManager resourceManager $ async do
+        -- This implementation is a temporary compatability wrapper and forking isn't necessary with the new design.
+        forkTask do
           attachDisposable resourceManager =<< liftIO do
             unsafeAsyncObserveIO observable \msg -> do
               currentMsgId <- atomically do
@@ -119,11 +116,7 @@ class IsRetrievable v o => IsObservable v o | o -> v where
 
   unsafeAsyncObserveIO :: o -> (ObservableMessage v -> IO ()) -> IO Disposable
   unsafeAsyncObserveIO observable callback = do
-    resourceManager <- unsafeNewResourceManager
-    onResourceManager resourceManager do
-      asyncObserve_ observable (liftIO . callback)
-
-    pure (toDisposable resourceManager)
+    forkTask_ $ withOnResourceManager $ observe observable (liftIO . callback)
 
   toObservable :: o -> Observable v
   toObservable = Observable
@@ -134,11 +127,8 @@ class IsRetrievable v o => IsObservable v o | o -> v where
   {-# MINIMAL observe | unsafeAsyncObserveIO #-}
 
 
-asyncObserve :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m Disposable
-asyncObserve observable callback = toDisposable <$> async (observe observable callback)
-
-asyncObserve_ :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m ()
-asyncObserve_ observable callback = async_ (observe observable callback)
+asyncObserve :: IsObservable v o => MonadAsync m => o -> (ObservableMessage v -> m ()) -> m ()
+asyncObserve observable callback = async_ (observe observable callback)
 
 
 data ObserveWhileCompleted = ObserveWhileCompleted
@@ -225,9 +215,9 @@ instance IsObservable v (MappedObservable v) where
 data BindObservable r = forall a. BindObservable (Observable a) (a -> Observable r)
 
 instance IsRetrievable r (BindObservable r) where
-  retrieve (BindObservable fx fn) = async $ do
-    x <- awaitResult $ retrieve fx
-    awaitResult $ retrieve $ fn x
+  retrieve (BindObservable fx fn) = do
+    x <- await =<< retrieve fx
+    retrieve $ fn x
 
 instance IsObservable r (BindObservable r) where
   unsafeAsyncObserveIO :: BindObservable r -> (ObservableMessage r -> IO ()) -> IO Disposable
@@ -294,8 +284,7 @@ instance IsObservable r (BindObservable r) where
 data CatchObservable e r = Exception e => CatchObservable (Observable r) (e -> Observable r)
 
 instance IsRetrievable r (CatchObservable e r) where
-  retrieve (CatchObservable fx fn) = async $
-    awaitResult (retrieve fx) `catch` \ex -> awaitResult (retrieve (fn ex))
+  retrieve (CatchObservable fx fn) = retrieve fx `catch` \ex -> retrieve (fn ex)
 
 instance IsObservable r (CatchObservable e r) where
   unsafeAsyncObserveIO :: CatchObservable e r -> (ObservableMessage r -> IO ()) -> IO Disposable
@@ -361,7 +350,7 @@ instance IsObservable r (CatchObservable e r) where
 
 newtype ObservableVar v = ObservableVar (MVar (v, HM.HashMap Unique (ObservableCallback v)))
 instance IsRetrievable v (ObservableVar v) where
-  retrieve (ObservableVar mvar) = liftIO $ successfulTask . fst <$> readMVar mvar
+  retrieve (ObservableVar mvar) = liftIO $ pure . fst <$> readMVar mvar
 instance IsObservable v (ObservableVar v) where
   unsafeAsyncObserveIO (ObservableVar mvar) callback = do
     key <- newUnique
@@ -448,7 +437,7 @@ mergeObservable :: (IsObservable v0 o0, IsObservable v1 o1) => (v0 -> v1 -> r) -
 mergeObservable merge x y = Observable $ MergedObservable merge x y
 
 data FnObservable v = FnObservable {
-  retrieveFn :: forall m. MonadAsync m => m (Task v),
+  retrieveFn :: forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v),
   observeFn :: (ObservableMessage v -> IO ()) -> IO Disposable
 }
 instance IsRetrievable v (FnObservable v) where
@@ -463,7 +452,7 @@ instance IsObservable v (FnObservable v) where
 -- | Implement an Observable by directly providing functions for `retrieve` and `subscribe`.
 fnObservable
   :: ((ObservableMessage v -> IO ()) -> IO Disposable)
-  -> (forall m. MonadAsync m => m (Task v))
+  -> (forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v))
   -> Observable v
 fnObservable observeFn retrieveFn = toObservable FnObservable{observeFn, retrieveFn}
 
@@ -474,8 +463,8 @@ synchronousFnObservable
   -> Observable v
 synchronousFnObservable observeFn synchronousRetrieveFn = fnObservable observeFn retrieveFn
   where
-    retrieveFn :: (forall m. MonadAsync m => m (Task v))
-    retrieveFn = liftIO $ successfulTask <$> synchronousRetrieveFn
+    retrieveFn :: (forall m. (MonadResourceManager m, MonadAwait m) => m (Awaitable v))
+    retrieveFn = liftIO $ pure <$> synchronousRetrieveFn
 
 
 newtype ConstObservable v = ConstObservable v
diff --git a/src/Quasar/Timer.hs b/src/Quasar/Timer.hs
index 76eeacd..76511d5 100644
--- a/src/Quasar/Timer.hs
+++ b/src/Quasar/Timer.hs
@@ -125,7 +125,7 @@ startSchedulerThread scheduler = do
 
     wait :: Timer -> Int -> IO ()
     wait nextTimer microseconds = do
-      delay <- toAwaitable <$> newDelay resourceManager' microseconds
+      delay <- onResourceManager resourceManager' $ toAwaitable <$> newDelay microseconds
       awaitAny2 delay nextTimerChanged
       where
         nextTimerChanged :: Awaitable ()
@@ -189,8 +189,13 @@ newtype Delay = Delay (Task ())
 instance IsAwaitable () Delay where
   toAwaitable (Delay task) = toAwaitable task `catch` \TaskDisposed -> throwM TimerCancelled
 
-newDelay :: ResourceManager -> Int -> IO Delay
-newDelay resourceManager microseconds = onResourceManager resourceManager $ Delay <$> async (liftIO (threadDelay microseconds))
+newDelay :: MonadResourceManager m => Int -> m Delay
+newDelay microseconds = do
+  resourceManager <- askResourceManager
+  mask_ do
+    delay <- Delay <$> forkTask (liftIO (threadDelay microseconds))
+    attachDisposable resourceManager delay
+    pure delay
 
 
 
diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs
index 8034d9f..15f0d2a 100644
--- a/test/Quasar/AsyncSpec.hs
+++ b/test/Quasar/AsyncSpec.hs
@@ -14,10 +14,10 @@ spec :: Spec
 spec = parallel $ do
   describe "async" $ do
     it "can pass a value through async and await" $ do
-      withOnResourceManager (await =<< async (pure 42)) `shouldReturn` (42 :: Int)
+      withOnResourceManager (runUnlimitedAsync (await =<< async (pure 42))) `shouldReturn` (42 :: Int)
 
     it "can pass a value through async and await" $ do
-      withOnResourceManager (await =<< async (liftIO (threadDelay 100000) >> pure 42)) `shouldReturn` (42 :: Int)
+      withOnResourceManager (runUnlimitedAsync (await =<< async (liftIO (threadDelay 100000) >> pure 42))) `shouldReturn` (42 :: Int)
 
   describe "await" $ do
     it "can await the result of an async that is completed later" $ do
-- 
GitLab