From 79e0acd3571cdc30ac8bdbdc5747671a117f91d4 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Sun, 25 Jul 2021 01:58:51 +0200
Subject: [PATCH] Implement AsyncIO cancellation

---
 src/Quasar/Core.hs       | 250 +++++++++++++++++++++++++++++----------
 test/Quasar/AsyncSpec.hs |   6 +
 2 files changed, 191 insertions(+), 65 deletions(-)

diff --git a/src/Quasar/Core.hs b/src/Quasar/Core.hs
index 17826fa..81eed95 100644
--- a/src/Quasar/Core.hs
+++ b/src/Quasar/Core.hs
@@ -23,11 +23,13 @@ module Quasar.Core (
 
   -- * Disposable
   IsDisposable(..),
-  disposeIO,
   Disposable,
   mkDisposable,
   synchronousDisposable,
   noDisposable,
+
+  -- * Cancellation
+  withCancellationToken,
 ) where
 
 import Control.Concurrent (forkIOWithUnmask)
@@ -35,6 +37,7 @@ import Control.Exception (MaskingState(..), getMaskingState)
 import Control.Monad.Catch
 import Data.HashMap.Strict qualified as HM
 import Data.Maybe (isJust)
+import Data.Void (absurd)
 import Quasar.Prelude
 
 -- * Async
@@ -42,22 +45,17 @@ import Quasar.Prelude
 class IsAsync r a | a -> r where
   -- | Wait until the promise is settled and return the result.
   wait :: a -> IO r
-  wait x = do
-    mvar <- newEmptyMVar
-    onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar)
-    readMVar mvar >>= either throwIO pure
-    where
-      resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
-      resultCallback mvar result = do
-        success <- tryPutMVar mvar result
-        unless success $ fail "Callback was called multiple times"
+  wait = wait . toAsync
 
   peekAsync :: a -> IO (Maybe (Either SomeException r))
+  peekAsync = peekAsync . toAsync
 
   -- | Register a callback, that will be called once the promise is settled.
   -- If the promise is already settled, the callback will be called immediately instead.
   --
   -- The returned `Disposable` can be used to deregister the callback.
+  --
+  -- 'onResult' should not throw.
   onResult
     :: a
     -- ^ async
@@ -65,18 +63,21 @@ class IsAsync r a | a -> r where
     -- ^ callback exception handler
     -> (Either SomeException r -> IO ())
     -- ^ callback
-    -> IO Disposable
+    -> IO CallbackDisposable
+  onResult x ceh c = onResult (toAsync x) ceh c
 
   onResult_
     :: a
     -> (SomeException -> IO ())
     -> (Either SomeException r -> IO ())
     -> IO ()
-  onResult_ x y = void . onResult x y
+  onResult_ x ceh c = onResult_ (toAsync x) ceh c
 
   toAsync :: a -> Async r
   toAsync = SomeAsync
 
+  {-# MINIMAL toAsync | (wait, peekAsync, onResult, onResult_) #-}
+
 
 data Async r = forall a. IsAsync r a => SomeAsync a
 
@@ -86,12 +87,18 @@ instance IsAsync r (Async r) where
   onResult_ (SomeAsync x) y = onResult_ x y
   peekAsync (SomeAsync x) = peekAsync x
 
+instance Functor Async where
+  fmap fn = toAsync . MappedAsync fn
+
 
 
 newtype CompletedAsync r = CompletedAsync (Either SomeException r)
 instance IsAsync r (CompletedAsync r) where
   wait (CompletedAsync value) = either throwIO pure value
-  onResult (CompletedAsync value) callbackExceptionHandler callback = noDisposable <$ (callback value `catch` callbackExceptionHandler)
+  onResult (CompletedAsync value) callbackExceptionHandler callback =
+    noCallbackDisposable <$ (callback value `catch` callbackExceptionHandler)
+  onResult_ (CompletedAsync value) callbackExceptionHandler callback =
+    callback value `catch` callbackExceptionHandler
   peekAsync (CompletedAsync value) = pure $ Just value
 
 completedAsync :: Either SomeException r -> Async r
@@ -104,6 +111,14 @@ failedAsync :: SomeException -> Async r
 failedAsync = completedAsync . Left
 
 
+data MappedAsync r = forall a. MappedAsync (a -> r) (Async a)
+instance IsAsync r (MappedAsync r) where
+  wait (MappedAsync fn x) = fn <$> wait x
+  peekAsync (MappedAsync fn x) = fmap fn <<$>> peekAsync x
+  onResult (MappedAsync fn x) callbackExceptionHandler callback = onResult x callbackExceptionHandler $ callback . fmap fn
+  onResult_ (MappedAsync fn x) callbackExceptionHandler callback = onResult_ x callbackExceptionHandler $ callback . fmap fn
+
+
 -- * AsyncIO
 
 data AsyncIO r
@@ -111,10 +126,14 @@ data AsyncIO r
   | AsyncIOFailure SomeException
   | AsyncIOIO (IO r)
   | AsyncIOAsync (Async r)
-  | AsyncIOPlumbing (IO (AsyncIO r))
+  | AsyncIOPlumbing (MaskingState -> CancellationToken -> IO (AsyncIO r))
 
 instance Functor AsyncIO where
-  fmap fn = (>>= pure . fn)
+  fmap fn (AsyncIOSuccess x) = AsyncIOSuccess (fn x)
+  fmap _ (AsyncIOFailure x) = AsyncIOFailure x
+  fmap fn (AsyncIOIO x) = AsyncIOIO (fn <$> x)
+  fmap fn (AsyncIOAsync x) = AsyncIOAsync (fn <$> x)
+  fmap fn (AsyncIOPlumbing x) = mapPlumbing x (fmap (fmap fn))
 
 instance Applicative AsyncIO where
   pure = AsyncIOSuccess
@@ -125,9 +144,11 @@ instance Monad AsyncIO where
   (>>=) :: forall a b. AsyncIO a -> (a -> AsyncIO b) -> AsyncIO b
   (>>=) (AsyncIOSuccess x) fn = fn x
   (>>=) (AsyncIOFailure x) _ = AsyncIOFailure x
-  (>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ either AsyncIOFailure fn <$> try x
+  (>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do
+    -- TODO masking and cancellation
+    either AsyncIOFailure fn <$> try x
   (>>=) (AsyncIOAsync x) fn = bindAsync x fn
-  (>>=) (AsyncIOPlumbing x) fn = AsyncIOPlumbing $ (>>= fn) <$> x
+  (>>=) (AsyncIOPlumbing x) fn = mapPlumbing x (fmap (>>= fn))
 
 instance MonadIO AsyncIO where
   liftIO = AsyncIOIO
@@ -141,39 +162,48 @@ instance MonadCatch AsyncIO where
   catch x@(AsyncIOFailure ex) handler = maybe x handler (fromException ex)
   catch (AsyncIOIO x) handler = AsyncIOIO (try x) >>= handleEither handler
   catch (AsyncIOAsync x) handler = bindAsyncCatch x (handleEither handler)
-  catch (AsyncIOPlumbing x) handler = AsyncIOPlumbing $ (`catch` handler) <$> x
+  catch (AsyncIOPlumbing x) handler = mapPlumbing x (fmap (`catch` handler))
 
 handleEither :: Exception e => (e -> AsyncIO a) -> Either SomeException a -> AsyncIO a
 handleEither handler (Left ex) = maybe (AsyncIOFailure ex) handler (fromException ex)
 handleEither _ (Right r) = pure r
 
+mapPlumbing :: (MaskingState -> CancellationToken -> IO (AsyncIO a)) -> (IO (AsyncIO a) -> IO (AsyncIO b)) -> AsyncIO b
+mapPlumbing plumbing fn = AsyncIOPlumbing $ \maskingState cancellationToken -> fn (plumbing maskingState cancellationToken)
+
 bindAsync :: forall a b. Async a -> (a -> AsyncIO b) -> AsyncIO b
 bindAsync x fn = bindAsyncCatch x (either AsyncIOFailure fn)
 
 bindAsyncCatch :: forall a b. Async a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b
-bindAsyncCatch x fn = AsyncIOPlumbing $ newAsyncVar >>= bindAsync'
+bindAsyncCatch x fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do
+  var <- newAsyncVar
+  disposableMVar <- newEmptyMVar
+  go maskingState cancellationToken var disposableMVar
   where
-    bindAsync' resultVar = do
-      withResult x resultVar step
-      pure $ await resultVar
-    step :: (Either SomeException b -> IO ()) -> Either SomeException a -> IO ()
-    step put = putAsyncIOResult put . fn
-
-withResult :: Async a -> AsyncVar b -> ((Either SomeException b -> IO ()) -> Either SomeException a -> IO ()) -> IO ()
-withResult x var fn = onResult_ x (failAsyncVar var) (fn (putAsyncVarEither var))
-
-putAsyncIOResult :: (Either SomeException a -> IO ()) -> AsyncIO a -> IO ()
-putAsyncIOResult put (AsyncIOSuccess x) = put (Right x)
-putAsyncIOResult put (AsyncIOFailure x) = put (Left x)
-putAsyncIOResult put (AsyncIOIO x) = try x >>= put
-putAsyncIOResult put (AsyncIOAsync x) = onResult_ x (put . Left) put
-putAsyncIOResult put (AsyncIOPlumbing x) = x >>= putAsyncIOResult put
+    go maskingState cancellationToken var disposableMVar = do
+      disposable <- onResult x (failAsyncVar_ var) $ \x -> do
+        (putAsyncIOResult . fn) x
+      -- TODO update mvar and dispose when completed
+      putMVar disposableMVar disposable
+      pure $ awaitUnlessCancellationRequested cancellationToken var
+      where
+        put = putAsyncVarEither var
+        putAsyncIOResult :: AsyncIO b -> IO ()
+        putAsyncIOResult (AsyncIOSuccess x) = put (Right x)
+        putAsyncIOResult (AsyncIOFailure x) = put (Left x)
+        putAsyncIOResult (AsyncIOIO x) = try x >>= put
+        putAsyncIOResult (AsyncIOAsync x) = onResult_ x (put . Left) put
+        putAsyncIOResult (AsyncIOPlumbing x) = x maskingState cancellationToken >>= putAsyncIOResult
 
 
 
 -- | Run the synchronous part of an `AsyncIO` and then return an `Async` that can be used to wait for completion of the synchronous part.
 async :: AsyncIO r -> AsyncIO (Async r)
-async = fmap successfulAsync
+async (AsyncIOSuccess x) = pure $ successfulAsync x
+async (AsyncIOFailure x) = pure $ failedAsync x
+async (AsyncIOIO x) = liftIO $ either failedAsync successfulAsync <$> try x
+async (AsyncIOAsync x) = pure x -- TODO caching
+async (AsyncIOPlumbing x) = mapPlumbing x (fmap async)
 
 await :: IsAsync r a => a -> AsyncIO r
 await = AsyncIOAsync . toAsync
@@ -185,7 +215,8 @@ runAsyncIO (AsyncIOFailure x) = throwIO x
 runAsyncIO (AsyncIOIO x) = x
 runAsyncIO (AsyncIOAsync x) = wait x
 runAsyncIO (AsyncIOPlumbing x) = do
-  x >>= runAsyncIO
+  maskingState <- getMaskingState
+  withCancellationToken $ x maskingState >=> runAsyncIO
 
 awaitResult :: AsyncIO (Async r) -> AsyncIO r
 awaitResult = (await =<<)
@@ -211,62 +242,127 @@ mapAsync fn = async . fmap fn . await
 -- | The default implementation for a `Async` that can be fulfilled later.
 newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r))
 data AsyncVarState r
-  = AsyncVarCompleted (Either SomeException r)
+  = AsyncVarCompleted (Either SomeException r) (IO ())
   | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ()))
 
 instance IsAsync r (AsyncVar r) where
+  wait x = do
+    mvar <- newEmptyMVar
+    onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar)
+    readMVar mvar >>= either throwIO pure
+    where
+      resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
+      resultCallback mvar result = do
+        success <- tryPutMVar mvar result
+        unless success $ fail "Callback was called multiple times"
+
   peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r))
   peekAsync (AsyncVar mvar) = readMVar mvar >>= pure . \case
-    AsyncVarCompleted x -> Just x
+    AsyncVarCompleted x _ -> Just x
     AsyncVarOpen _ -> Nothing
 
-  onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO Disposable
+  onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO CallbackDisposable
   onResult (AsyncVar mvar) callbackExceptionHandler callback =
     modifyMVar mvar $ \case
       AsyncVarOpen callbacks -> do
         key <- newUnique
-        pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), removeHandler key)
-      x@(AsyncVarCompleted value) -> (x, noDisposable) <$ callback value
+        pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), callbackDisposable key)
+      x@(AsyncVarCompleted value _) -> (x, noCallbackDisposable) <$ callback value `catch` callbackExceptionHandler
     where
-      removeHandler :: Unique -> Disposable
-      removeHandler key = synchronousDisposable $ modifyMVar_ mvar $ pure . \case
-        x@(AsyncVarCompleted _) -> x
-        AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x
+      callbackDisposable :: Unique -> CallbackDisposable
+      callbackDisposable key = CallbackDisposable removeHandler removeHandlerEventually
+        where
+          removeHandler = do
+            waitForCallbacks <- modifyMVar mvar $ pure . \case
+              x@(AsyncVarCompleted _ waitForCallbacks) -> (x, waitForCallbacks)
+              AsyncVarOpen x -> (AsyncVarOpen (HM.delete key x), pure ())
+            -- Dispose should only return after the callback can't be called any longer
+            -- If the callbacks are already being dispatched, wait for them to complete to keep the guarantee
+            waitForCallbacks
+
+          removeHandlerEventually =
+            modifyMVar_ mvar $ pure . \case
+              x@(AsyncVarCompleted _ _) -> x
+              AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x
+
+  onResult_ x y = void . onResult x y
+
+tryPutAsyncVarEither :: forall a m. MonadIO m => AsyncVar a -> Either SomeException a -> m Bool
+tryPutAsyncVarEither (AsyncVar mvar) value = liftIO $ do
+  action <- modifyMVar mvar $ \case
+    x@(AsyncVarCompleted _ waitForCallbacks) -> pure (x, False <$ waitForCallbacks)
+    AsyncVarOpen callbacksMap -> do
+      callbacksCompletedMVar <- newEmptyMVar
+      let waitForCallbacks = readMVar callbacksCompletedMVar
+          callbacks = HM.elems callbacksMap
+      pure (AsyncVarCompleted value waitForCallbacks, fireCallbacks callbacks callbacksCompletedMVar)
+
+  action
+
+  where
+    fireCallbacks :: [(Either SomeException a -> IO (), SomeException -> IO ())] -> MVar () -> IO Bool
+    fireCallbacks callbacks callbacksCompletedMVar = do
+      forM_ callbacks $ \(callback, callbackExceptionHandler) ->
+        callback value `catch` callbackExceptionHandler
+      putMVar callbacksCompletedMVar ()
+      pure True
 
 
 newAsyncVar :: MonadIO m => m (AsyncVar r)
 newAsyncVar = liftIO $ AsyncVar <$> newMVar (AsyncVarOpen HM.empty)
 
+
 putAsyncVar :: MonadIO m => AsyncVar a -> a -> m ()
 putAsyncVar var = putAsyncVarEither var . Right
 
-failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m ()
-failAsyncVar var = putAsyncVarEither var . Left
+tryPutAsyncVar :: MonadIO m => AsyncVar a -> a -> m Bool
+tryPutAsyncVar var = tryPutAsyncVarEither var . Right
+
+tryPutAsyncVar_ :: MonadIO m => AsyncVar a -> a -> m ()
+tryPutAsyncVar_ var = void . tryPutAsyncVar var
+
+failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m Bool
+failAsyncVar var = tryPutAsyncVarEither var . Left
+
+failAsyncVar_ :: MonadIO m => AsyncVar a -> SomeException -> m ()
+failAsyncVar_ var = void . failAsyncVar var
 
 putAsyncVarEither :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
-putAsyncVarEither (AsyncVar mvar) value = liftIO $ do
-  mask $ \restore -> do
-    takeMVar mvar >>= \case
-      x@(AsyncVarCompleted _) -> do
-        putMVar mvar x
-        fail "An AsyncVar can only be fulfilled once"
-      AsyncVarOpen callbacksMap -> do
-        let callbacks = HM.elems callbacksMap
-        -- NOTE disposing a callback while it is called is a deadlock
-        forM_ callbacks $ \(callback, callbackExceptionHandler) ->
-          restore (callback value) `catch` callbackExceptionHandler
-        putMVar mvar (AsyncVarCompleted value)
+putAsyncVarEither avar value = liftIO $ do
+  success <- tryPutAsyncVarEither avar value
+  unless success $ fail "An AsyncVar can only be fulfilled once"
+
+tryPutAsyncVarEither_ :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
+tryPutAsyncVarEither_ var = void . tryPutAsyncVarEither var
+
+
+-- * Awaiting multiple asyncs
+
+awaitEither :: (IsAsync ra a , IsAsync rb b) => a -> b -> AsyncIO (Either ra rb)
+awaitEither x y = AsyncIOPlumbing $ \_ _ -> AsyncIOAsync <$> awaitEitherPlumbing x y
+
+awaitEitherPlumbing :: (IsAsync ra a , IsAsync rb b) => a -> b -> IO (Async (Either ra rb))
+awaitEitherPlumbing x y = do
+  var <- newAsyncVar
+  d1 <- onResult x (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Left)
+  d2 <- onResult y (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Right)
+  -- The resulting async is kept in memory by 'x' or 'y' until one of them completes.
+  onResult_ var (const (pure ())) (const (disposeCallbackEventually d1 *> disposeCallbackEventually d2))
+  pure $ toAsync var
 
 
 -- * Cancellation
 
 newtype CancellationToken = CancellationToken (AsyncVar Void)
 
+instance IsAsync Void CancellationToken where
+  toAsync (CancellationToken var) = toAsync var
+
 newCancellationToken :: IO CancellationToken
 newCancellationToken = CancellationToken <$> newAsyncVar
 
 cancel :: Exception e => CancellationToken -> e -> IO ()
-cancel (CancellationToken var) = failAsyncVar var . toException
+cancel (CancellationToken var) = failAsyncVar_ var . toException
 
 isCancellationRequested :: CancellationToken -> IO Bool
 isCancellationRequested (CancellationToken var) = isJust <$> peekAsync var
@@ -280,6 +376,10 @@ throwIfCancellationRequested (CancellationToken var) =
     Just (Left ex) -> throwIO ex
     _ -> pure ()
 
+awaitUnlessCancellationRequested :: IsAsync a b => CancellationToken -> b -> AsyncIO a
+awaitUnlessCancellationRequested cancellationToken = fmap (either absurd id) . awaitEither cancellationToken . toAsync
+
+
 withCancellationToken :: (CancellationToken -> IO a) -> IO a
 withCancellationToken action = do
   cancellationToken <- newCancellationToken
@@ -289,7 +389,8 @@ withCancellationToken action = do
     void $ forkIOWithUnmask $ \threadUnmask -> do
       putMVar resultMVar =<< try (threadUnmask (action cancellationToken))
 
-    either throwIO pure =<< (unmask (takeMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> takeMVar resultMVar))
+    -- TODO test if it is better to run readMVar recursively or to keep it uninterruptible
+    either throwIO pure =<< (unmask (readMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> readMVar resultMVar))
 
 
 -- * Disposable
@@ -300,21 +401,22 @@ class IsDisposable a where
   -- | Dispose a resource.
   dispose :: a -> AsyncIO ()
 
+  -- | Dispose a resource in the IO monad.
+  disposeIO :: a -> IO ()
+
   toDisposable :: a -> Disposable
   toDisposable = mkDisposable . dispose
 
--- | Dispose a resource in the IO monad.
-disposeIO :: IsDisposable a => a -> IO ()
-disposeIO = runAsyncIO . dispose
-
 instance IsDisposable a => IsDisposable (Maybe a) where
   dispose = mapM_ dispose
+  disposeIO = mapM_ disposeIO
 
 
 newtype Disposable = Disposable (AsyncIO ())
 
 instance IsDisposable Disposable where
   dispose (Disposable fn) = fn
+  disposeIO = runAsyncIO . dispose
   toDisposable = id
 
 instance Semigroup Disposable where
@@ -333,3 +435,21 @@ synchronousDisposable = mkDisposable . liftIO
 
 noDisposable :: Disposable
 noDisposable = mempty
+
+
+
+data CallbackDisposable = CallbackDisposable (IO ()) (IO ())
+
+instance IsDisposable CallbackDisposable where
+  dispose = liftIO . disposeCallback
+  disposeIO = disposeCallback
+  toDisposable = Disposable . dispose
+
+disposeCallback :: CallbackDisposable -> IO ()
+disposeCallback (CallbackDisposable f _) = f
+
+disposeCallbackEventually :: CallbackDisposable -> IO ()
+disposeCallbackEventually (CallbackDisposable _ e) = e
+
+noCallbackDisposable :: CallbackDisposable
+noCallbackDisposable = CallbackDisposable mempty mempty
diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs
index ace2f49..55c4684 100644
--- a/test/Quasar/AsyncSpec.hs
+++ b/test/Quasar/AsyncSpec.hs
@@ -80,3 +80,9 @@ spec = parallel $ do
         -- Use bind to create an AsyncIOPlumbing, which is the interesting case that uses `uninterruptibleMask` when run
         await never >>= pure
       result `shouldBe` Nothing
+
+  describe "CancellationToken" $ do
+    it "can be waited upon" $ do
+      result <- timeout 100000 $ withCancellationToken wait
+      result `shouldBe` Nothing -- `wait` re-throws the exception
+
-- 
GitLab