From df8e96e45ddfdaf988e7083613b0d62620db0711 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Mon, 21 Feb 2022 18:11:37 +0100
Subject: [PATCH] Improve async implementation

---
 src/Quasar/Async/V2.hs      | 97 ++++++++++++++++++++-----------------
 src/Quasar/Monad.hs         |  7 ++-
 src/Quasar/Utils/ShortIO.hs |  2 +-
 3 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/src/Quasar/Async/V2.hs b/src/Quasar/Async/V2.hs
index 6b056f7..0a10cc1 100644
--- a/src/Quasar/Async/V2.hs
+++ b/src/Quasar/Async/V2.hs
@@ -11,17 +11,21 @@ module Quasar.Async.V2 (
   AsyncException(..),
   isCancelAsync,
   isAsyncDisposed,
+
+  -- ** IO variant
+  async',
+  asyncWithUnmask',
 ) where
 
 import Control.Concurrent (ThreadId)
 import Control.Concurrent.STM
 import Control.Monad.Catch
 import Quasar.Async.Fork
-import Quasar.Async.STMHelper
 import Quasar.Awaitable
 import Quasar.Exceptions
 import Quasar.Monad
 import Quasar.Prelude
+import Quasar.Resources
 import Quasar.Resources.Disposer
 import Quasar.Utils.ShortIO
 import Control.Monad.Reader
@@ -36,20 +40,52 @@ instance IsAwaitable a (Async a) where
   toAwaitable (Async awaitable _) = awaitable
 
 
-unmanagedAsyncSTM :: IO a -> TIOWorker -> ExceptionChannel -> STM (Async a)
-unmanagedAsyncSTM fn = unmanagedAsyncWithUnmaskSTM (\unmask -> unmask fn)
+async :: MonadQuasar m => QuasarIO a -> m (Async a)
+async fn = asyncWithUnmask ($ fn)
+
+async_ :: MonadQuasar m => QuasarIO () -> m ()
+async_ fn = void $ asyncWithUnmask ($ fn)
+
+asyncWithUnmask :: MonadQuasar m => ((forall b. QuasarIO b -> QuasarIO b) -> QuasarIO a) -> m (Async a)
+asyncWithUnmask fn = do
+  quasar <- askQuasar
+  asyncWithUnmask' (\unmask -> runReaderT (fn (liftUnmask unmask)) quasar)
+  where
+    liftUnmask :: (forall b. IO b -> IO b) -> QuasarIO a -> QuasarIO a
+    liftUnmask unmask innerAction = do
+      quasar <- askQuasar
+      liftIO $ unmask $ runReaderT innerAction quasar
+
+asyncWithUnmask_ :: MonadQuasar m => ((forall b. QuasarIO b -> QuasarIO b) -> QuasarIO ()) -> m ()
+asyncWithUnmask_ fn = void $ asyncWithUnmask fn
+
+
+async' :: MonadQuasar m => IO a -> m (Async a)
+async' fn = asyncWithUnmask' ($ fn)
+
+asyncWithUnmask' :: forall a m. MonadQuasar m => ((forall b. IO b -> IO b) -> IO a) -> m (Async a)
+asyncWithUnmask' fn = maskIfRequired do
+  worker <- askIOWorker
+  exChan <- askExceptionChannel
+
+  (key, resultVar, threadIdVar, disposer) <- ensureSTM do
+    key <- newUniqueSTM
+    resultVar <- newAsyncVarSTM
+    threadIdVar <- newAsyncVarSTM
+    -- Disposer is created first to ensure the resource can be safely attached
+    disposer <- newPrimitiveDisposer (disposeFn key resultVar (toAwaitable threadIdVar)) worker exChan
+    pure (key, resultVar, threadIdVar, disposer)
+
+  registerResource disposer
+
+  startShortIO_ do
+    threadId <- forkWithUnmaskShortIO (runAndPut exChan key resultVar disposer) exChan
+    putAsyncVarShortIO_ threadIdVar threadId
 
-unmanagedAsyncWithUnmaskSTM :: forall a. ((forall b. IO b -> IO b) -> IO a) -> TIOWorker -> ExceptionChannel -> STM (Async a)
-unmanagedAsyncWithUnmaskSTM fn worker exChan = do
-  key <- newUniqueSTM
-  resultVar <- newAsyncVarSTM
-  disposer <- mfix \disposer -> do
-    tidAwaitable <- forkWithUnmaskSTM (runAndPut key resultVar disposer) worker exChan
-    newPrimitiveDisposer (disposeFn key resultVar tidAwaitable) worker exChan
   pure $ Async (toAwaitable resultVar) disposer
   where
-    runAndPut :: Unique -> AsyncVar a -> Disposer -> (forall b. IO b -> IO b) -> IO ()
-    runAndPut key resultVar disposer unmask = do
+    runAndPut :: ExceptionChannel -> Unique -> AsyncVar a -> Disposer -> (forall b. IO b -> IO b) -> IO ()
+    runAndPut exChan key resultVar disposer unmask = do
       -- Called in masked state by `forkWithUnmask`
       result <- try $ fn unmask
       case result of
@@ -64,36 +100,9 @@ unmanagedAsyncWithUnmaskSTM fn worker exChan = do
           putAsyncVar_ resultVar retVal
           atomically $ disposeEventuallySTM_ disposer
     disposeFn :: Unique -> AsyncVar a -> Awaitable ThreadId -> ShortIO (Awaitable ())
-    disposeFn key resultVar tidAwaitable = do
-      -- Awaits forking of the thread, which should happen immediately (as long as the TIOWorker-invariant isn't broken elsewhere)
-      tid <- unsafeShortIO $ await tidAwaitable
-      -- `throwTo` should also happen immediately, as long as `uninterruptibleMask` isn't abused elsewhere
-      throwToShortIO tid (CancelAsync key)
+    disposeFn key resultVar threadIdAwaitable = do
+      -- Should not block or fail (unless the TIOWorker is broken)
+      threadId <- unsafeShortIO $ await threadIdAwaitable
+      throwToShortIO threadId (CancelAsync key)
       -- Considered complete once a result (i.e. success or failure) has been stored
-      pure (() <$ toAwaitable resultVar)
-
-
-async :: MonadQuasar m => QuasarIO a -> m (Async a)
-async fn = asyncWithUnmask ($ fn)
-
-async_ :: MonadQuasar m => QuasarIO () -> m ()
-async_ fn = void $ asyncWithUnmask ($ fn)
-
-asyncWithUnmask :: MonadQuasar m => ((forall b. QuasarIO b -> QuasarIO b) -> QuasarIO a) -> m (Async a)
-asyncWithUnmask fn = do
-  quasar <- askQuasar
-  worker <- askIOWorker
-  exChan <- askExceptionChannel
-  rm <- askResourceManager
-  ensureSTM do
-    as <- unmanagedAsyncWithUnmaskSTM (\unmask -> runReaderT (fn (liftUnmask unmask)) quasar) worker exChan
-    attachResource rm as
-    pure as
-  where
-    liftUnmask :: (forall b. IO b -> IO b) -> QuasarIO a -> QuasarIO a
-    liftUnmask unmask innerAction = do
-      quasar <- askQuasar
-      liftIO $ unmask $ runReaderT innerAction quasar
-
-asyncWithUnmask_ :: MonadQuasar m => ((forall b. QuasarIO b -> QuasarIO b) -> QuasarIO ()) -> m ()
-asyncWithUnmask_ fn = void $ asyncWithUnmask fn
+      pure (awaitSuccessOrFailure resultVar)
diff --git a/src/Quasar/Monad.hs b/src/Quasar/Monad.hs
index 1fb80f8..2894462 100644
--- a/src/Quasar/Monad.hs
+++ b/src/Quasar/Monad.hs
@@ -17,6 +17,8 @@ module Quasar.Monad (
 
   enterQuasarIO,
   enterQuasarSTM,
+
+  startShortIO_,
 ) where
 
 import Control.Concurrent.STM
@@ -94,7 +96,7 @@ instance (MonadIO m, MonadMask m, MonadFix m) => MonadQuasar (QuasarT m) where
   maskIfRequired = mask_
   startShortIO fn = do
     exChan <- askExceptionChannel
-    liftIO $ try (runShortIO fn) >>= \case
+    liftIO $ uninterruptibleMask_ $ try (runShortIO fn) >>= \case
       Left ex -> do
         atomically $ throwToExceptionChannel exChan ex
         pure $ throwM $ toException $ AsyncException ex
@@ -133,6 +135,9 @@ instance {-# OVERLAPPABLE #-} MonadQuasar m => MonadQuasar (ReaderT r m) where
 -- TODO MonadQuasar instances for StateT, WriterT, RWST, MaybeT, ...
 
 
+startShortIO_ :: MonadQuasar m => ShortIO () -> m ()
+startShortIO_ fn = void $ startShortIO fn
+
 askIOWorker :: MonadQuasar m => m TIOWorker
 askIOWorker = quasarIOWorker <$> askQuasar
 
diff --git a/src/Quasar/Utils/ShortIO.hs b/src/Quasar/Utils/ShortIO.hs
index 294dbbd..1e5fb1b 100644
--- a/src/Quasar/Utils/ShortIO.hs
+++ b/src/Quasar/Utils/ShortIO.hs
@@ -20,7 +20,7 @@ import Quasar.Prelude
 import Control.Concurrent
 
 newtype ShortIO a = ShortIO (IO a)
-  deriving newtype (Functor, Applicative, Monad, MonadThrow, MonadCatch, MonadMask)
+  deriving newtype (Functor, Applicative, Monad, MonadThrow, MonadCatch, MonadMask, MonadFix)
 
 runShortIO :: ShortIO a -> IO a
 runShortIO (ShortIO fn) = fn
-- 
GitLab