From b22a42e6f17583d0f55f294f7ab661be1c2de0ae Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Mon, 23 Aug 2021 01:38:02 +0200
Subject: [PATCH] Implement Task cancellation

---
 src/Quasar/Async.hs     | 71 ++++++++++++++++++++++++++++++-----------
 src/Quasar/Awaitable.hs |  4 +++
 2 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/src/Quasar/Async.hs b/src/Quasar/Async.hs
index 0f9bfa9..a28801f 100644
--- a/src/Quasar/Async.hs
+++ b/src/Quasar/Async.hs
@@ -26,7 +26,7 @@ module Quasar.Async (
   unlimitedAsyncManagerConfiguration,
 ) where
 
-import Control.Concurrent (ThreadId, forkIOWithUnmask)
+import Control.Concurrent (ThreadId, forkIOWithUnmask, throwTo)
 import Control.Concurrent.STM
 import Control.Monad.Catch
 import Control.Monad.Reader
@@ -46,16 +46,47 @@ async :: MonadAsync m => AsyncIO r -> m (Task r)
 async action = asyncWithUnmask (\unmask -> unmask action)
 
 -- | Run the synchronous part of an `AsyncIO` and then return an `Awaitable` that can be used to wait for completion of the synchronous part.
+--
+-- The action will be run with asynchronous exceptions masked and will be passed an action that can be used unmask.
 asyncWithUnmask :: MonadAsync m => ((forall a. AsyncIO a -> AsyncIO a) -> AsyncIO r) -> m (Task r)
 -- TODO resource limits
 asyncWithUnmask action = do
   asyncManager <- askAsyncManager
-  resultVar <- newAsyncVar
-  liftIO $ mask_ $ do
-    void $ forkIOWithUnmask $ \unmask -> do
-      result <- try $ runOnAsyncManager asyncManager (action (liftUnmask unmask))
-      putAsyncVarEither_ resultVar result
-    pure $ Task (toAwaitable resultVar)
+
+  liftIO $ mask_ do
+    resultVar <- newAsyncVar
+    threadIdVar <- newEmptyTMVarIO
+
+    disposable <- attachDisposeAction (getResourceManager asyncManager) (disposeTask threadIdVar resultVar)
+
+    onException
+      do
+        atomically . putTMVar threadIdVar . Just =<<
+          forkIOWithUnmask \unmask -> do
+            result <- try $ catch
+              do runOnAsyncManager asyncManager (action (liftUnmask unmask))
+              \CancelTask -> throwIO TaskDisposed
+
+            putAsyncVarEither_ resultVar result
+
+            -- Thread has completed work, "disarm" the disposable and fire it
+            void $ atomically $ swapTMVar threadIdVar Nothing
+            disposeIO disposable
+
+      do atomically $ putTMVar threadIdVar Nothing
+
+    pure $ Task disposable (toAwaitable resultVar)
+  where
+    disposeTask :: TMVar (Maybe ThreadId) -> AsyncVar r -> IO (Awaitable ())
+    disposeTask threadIdVar resultVar = mask_ do
+      -- Blocks until the thread is forked
+      atomically (swapTMVar threadIdVar Nothing) >>= \case
+        -- Thread completed or initialization failed
+        Nothing -> pure ()
+        Just threadId -> throwTo threadId CancelTask
+
+      -- Wait for task completion or failure. Tasks must not ignore `CancelTask` or this will hang.
+      pure $ mapAwaitable (const $ pure ()) resultVar
 
 liftUnmask :: (IO a -> IO a) -> AsyncIO a -> AsyncIO a
 liftUnmask unmask action = do
@@ -87,7 +118,7 @@ data AsyncManager = AsyncManager {
 }
 
 instance IsDisposable AsyncManager where
-  toDisposable = undefined
+  toDisposable = toDisposable . getResourceManager
 
 instance HasResourceManager AsyncManager where
   getResourceManager = resourceManager
@@ -97,20 +128,20 @@ instance HasResourceManager AsyncManager where
 -- The result (or exception) can be aquired by using the `IsAwaitable` class (e.g. by calling `await` or `awaitIO`).
 -- It might be possible to cancel the task by using the `IsDisposable` class if the operation has not been completed.
 -- If the result is no longer required the task should be cancelled, to avoid leaking memory.
-newtype Task r = Task (Awaitable r)
+data Task r = Task Disposable (Awaitable r)
 
 instance IsAwaitable r (Task r) where
-  toAwaitable (Task awaitable) = awaitable
+  toAwaitable (Task _ awaitable) = awaitable
 
 instance IsDisposable (Task r) where
-  toDisposable = undefined
+  toDisposable (Task disposable _) = disposable
 
 instance Functor Task where
-  fmap fn (Task x) = Task (fn <$> x)
+  fmap fn (Task disposable awaitable) = Task disposable (fn <$> awaitable)
 
 instance Applicative Task where
-  pure = Task . pure
-  liftA2 fn (Task fx) (Task fy) = Task $ liftA2 fn fx fy
+  pure value = Task noDisposable (pure value)
+  liftA2 fn (Task dx fx) (Task dy fy) = Task (dx <> dy) $ liftA2 fn fx fy
 
 cancelTask :: Task r -> IO (Awaitable ())
 cancelTask = dispose
@@ -121,17 +152,17 @@ cancelTaskIO = awaitIO <=< dispose
 -- | Creates an `Task` from an `Awaitable`.
 -- The resulting task only depends on an external resource, so disposing it has no effect.
 toTask :: IsAwaitable r a => a -> Task r
-toTask = Task . toAwaitable
+toTask result = Task noDisposable (toAwaitable result)
 
 completedTask :: Either SomeException r -> Task r
-completedTask = toTask . completedAwaitable
+completedTask result = Task noDisposable (completedAwaitable result)
 
 -- | Alias for `pure`
 successfulTask :: r -> Task r
 successfulTask = pure
 
 failedTask :: SomeException -> Task r
-failedTask = toTask . failedAwaitable
+failedTask ex = Task noDisposable (failedAwaitable ex)
 
 
 
@@ -139,9 +170,9 @@ data CancelTask = CancelTask
   deriving stock Show
 instance Exception CancelTask where
 
-data CancelledTask = CancelledTask
+data TaskDisposed = TaskDisposed
   deriving stock Show
-instance Exception CancelledTask where
+instance Exception TaskDisposed where
 
 
 data AsyncManagerConfiguraiton = AsyncManagerConfiguraiton {
@@ -172,8 +203,10 @@ withUnlimitedAsyncManager = withAsyncManager unlimitedAsyncManagerConfiguration
 
 newAsyncManager :: AsyncManagerConfiguraiton -> IO AsyncManager
 newAsyncManager configuration = do
+  resourceManager <- newResourceManager
   threads <- newTVarIO mempty
   pure AsyncManager {
+    resourceManager,
     configuration,
     threads
   }
diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index c182c80..e0f3941 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -9,6 +9,7 @@ module Quasar.Awaitable (
   failedAwaitable,
   completedAwaitable,
   simpleAwaitable,
+  mapAwaitable,
 
   -- * Awaiting multiple awaitables
   cacheAwaitable,
@@ -91,6 +92,9 @@ failedAwaitable = completedAwaitable . Left
 simpleAwaitable :: STM (Maybe (Either SomeException a)) -> Awaitable a
 simpleAwaitable query = Awaitable (querySTM query)
 
+mapAwaitable :: IsAwaitable i a => (Either SomeException i -> Either SomeException r) -> a -> Awaitable r
+mapAwaitable fn awaitable = Awaitable $ fn <$> runAwaitable awaitable
+
 
 class Monad m => MonadQuerySTM m where
   querySTM :: (forall a. STM (Maybe a) -> m a)
-- 
GitLab