From 94db33bd5ccb54abf863a7ef0dc20ffae17e130c Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Sun, 25 Jul 2021 18:17:54 +0200
Subject: [PATCH] Implement Awaitable based on stm (replacing Async)

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 quasar.cabal             |   2 +
 src/Quasar/Core.hs       | 336 ++++++++++++++-------------------------
 test/Quasar/AsyncSpec.hs |  25 +--
 3 files changed, 133 insertions(+), 230 deletions(-)

diff --git a/quasar.cabal b/quasar.cabal
index 1d32814..ab58ee9 100644
--- a/quasar.cabal
+++ b/quasar.cabal
@@ -71,6 +71,7 @@ library
     microlens-platform,
     mtl,
     record-hasfield,
+    stm,
     template-haskell,
     transformers,
     unordered-containers,
@@ -93,6 +94,7 @@ test-suite quasar-test
     base >=4.7 && <5,
     hspec,
     quasar,
+    stm,
     unordered-containers,
   main-is: Spec.hs
   other-modules:
diff --git a/src/Quasar/Core.hs b/src/Quasar/Core.hs
index 81eed95..cf08fd3 100644
--- a/src/Quasar/Core.hs
+++ b/src/Quasar/Core.hs
@@ -1,10 +1,17 @@
 module Quasar.Core (
-  -- * Async
-  IsAsync(..),
-  Async,
-  successfulAsync,
-  failedAsync,
-  completedAsync,
+  -- * Awaitable
+  IsAwaitable(..),
+  awaitSTM,
+  Awaitable,
+  successfulAwaitable,
+  failedAwaitable,
+  completedAwaitable,
+  peekAwaitable,
+
+  -- * AsyncVar
+  AsyncVar,
+  newAsyncVar,
+  putAsyncVar,
 
   -- * AsyncIO
   AsyncIO,
@@ -13,14 +20,6 @@ module Quasar.Core (
   runAsyncIO,
   awaitResult,
 
-  -- * Async helpers
-  mapAsync,
-
-  -- * AsyncVar
-  AsyncVar,
-  newAsyncVar,
-  putAsyncVar,
-
   -- * Disposable
   IsDisposable(..),
   Disposable,
@@ -33,90 +32,73 @@ module Quasar.Core (
 ) where
 
 import Control.Concurrent (forkIOWithUnmask)
+import Control.Concurrent.STM
 import Control.Exception (MaskingState(..), getMaskingState)
 import Control.Monad.Catch
-import Data.HashMap.Strict qualified as HM
 import Data.Maybe (isJust)
 import Data.Void (absurd)
 import Quasar.Prelude
 
--- * Async
+-- * Awaitable
+
+class IsAwaitable r a | a -> r where
+  peekSTM :: a -> STM (Maybe (Either SomeException r))
+  peekSTM = peekSTM . toAwaitable
+
+  toAwaitable :: a -> Awaitable r
+  toAwaitable = SomeAwaitable
 
-class IsAsync r a | a -> r where
-  -- | Wait until the promise is settled and return the result.
-  wait :: a -> IO r
-  wait = wait . toAsync
+  {-# MINIMAL toAwaitable | peekSTM #-}
 
-  peekAsync :: a -> IO (Maybe (Either SomeException r))
-  peekAsync = peekAsync . toAsync
 
-  -- | Register a callback, that will be called once the promise is settled.
-  -- If the promise is already settled, the callback will be called immediately instead.
-  --
-  -- The returned `Disposable` can be used to deregister the callback.
-  --
-  -- 'onResult' should not throw.
-  onResult
-    :: a
-    -- ^ async
-    -> (SomeException -> IO ())
-    -- ^ callback exception handler
-    -> (Either SomeException r -> IO ())
-    -- ^ callback
-    -> IO CallbackDisposable
-  onResult x ceh c = onResult (toAsync x) ceh c
+-- | Wait until the promise is settled and return the result.
+awaitSTM :: IsAwaitable r a => a -> STM (Either SomeException r)
+awaitSTM = peekSTM >=> maybe retry pure
 
-  onResult_
-    :: a
-    -> (SomeException -> IO ())
-    -> (Either SomeException r -> IO ())
-    -> IO ()
-  onResult_ x ceh c = onResult_ (toAsync x) ceh c
 
-  toAsync :: a -> Async r
-  toAsync = SomeAsync
+data Awaitable r = forall a. IsAwaitable r a => SomeAwaitable a
 
-  {-# MINIMAL toAsync | (wait, peekAsync, onResult, onResult_) #-}
+instance IsAwaitable r (Awaitable r) where
+  peekSTM (SomeAwaitable x) = peekSTM x
+  toAwaitable = id
 
+instance Functor Awaitable where
+  fmap fn = toAwaitable . FnAwaitable . fmap (fmap (fmap fn)) . peekSTM
 
-data Async r = forall a. IsAsync r a => SomeAsync a
 
-instance IsAsync r (Async r) where
-  wait (SomeAsync x) = wait x
-  onResult (SomeAsync x) y = onResult x y
-  onResult_ (SomeAsync x) y = onResult_ x y
-  peekAsync (SomeAsync x) = peekAsync x
 
-instance Functor Async where
-  fmap fn = toAsync . MappedAsync fn
+newtype CompletedAwaitable r = CompletedAwaitable (Either SomeException r)
+instance IsAwaitable r (CompletedAwaitable r) where
+  peekSTM (CompletedAwaitable value) = pure $ Just value
 
+completedAwaitable :: Either SomeException r -> Awaitable r
+completedAwaitable = toAwaitable . CompletedAwaitable
 
+successfulAwaitable :: r -> Awaitable r
+successfulAwaitable = completedAwaitable . Right
 
-newtype CompletedAsync r = CompletedAsync (Either SomeException r)
-instance IsAsync r (CompletedAsync r) where
-  wait (CompletedAsync value) = either throwIO pure value
-  onResult (CompletedAsync value) callbackExceptionHandler callback =
-    noCallbackDisposable <$ (callback value `catch` callbackExceptionHandler)
-  onResult_ (CompletedAsync value) callbackExceptionHandler callback =
-    callback value `catch` callbackExceptionHandler
-  peekAsync (CompletedAsync value) = pure $ Just value
+failedAwaitable :: SomeException -> Awaitable r
+failedAwaitable = completedAwaitable . Left
 
-completedAsync :: Either SomeException r -> Async r
-completedAsync = toAsync . CompletedAsync
 
-successfulAsync :: r -> Async r
-successfulAsync = completedAsync . Right
+peekAwaitable :: (IsAwaitable r a, MonadIO m) => a -> m (Maybe (Either SomeException r))
+peekAwaitable = liftIO . atomically . peekSTM
 
-failedAsync :: SomeException -> Async r
-failedAsync = completedAsync . Left
 
+newtype FnAwaitable r = FnAwaitable (STM (Maybe (Either SomeException r)))
+instance IsAwaitable r (FnAwaitable r) where
+  peekSTM (FnAwaitable fn) = fn
 
-data MappedAsync r = forall a. MappedAsync (a -> r) (Async a)
-instance IsAsync r (MappedAsync r) where
-  wait (MappedAsync fn x) = fn <$> wait x
-  peekAsync (MappedAsync fn x) = fmap fn <<$>> peekAsync x
-  onResult (MappedAsync fn x) callbackExceptionHandler callback = onResult x callbackExceptionHandler $ callback . fmap fn
-  onResult_ (MappedAsync fn x) callbackExceptionHandler callback = onResult_ x callbackExceptionHandler $ callback . fmap fn
+awaitableSTM :: STM (Maybe (Either SomeException r)) -> IO (Awaitable r)
+awaitableSTM fn = do
+  cache <- newTVarIO (Left fn)
+  pure . toAwaitable . FnAwaitable $
+    readTVar cache >>= \case
+      Left generatorFn -> do
+        value <- generatorFn
+        writeTVar cache (Right value)
+        pure value
+      Right value -> pure value
 
 
 -- * AsyncIO
@@ -125,7 +107,7 @@ data AsyncIO r
   = AsyncIOSuccess r
   | AsyncIOFailure SomeException
   | AsyncIOIO (IO r)
-  | AsyncIOAsync (Async r)
+  | AsyncIOAsync (Awaitable r)
   | AsyncIOPlumbing (MaskingState -> CancellationToken -> IO (AsyncIO r))
 
 instance Functor AsyncIO where
@@ -171,62 +153,57 @@ handleEither _ (Right r) = pure r
 mapPlumbing :: (MaskingState -> CancellationToken -> IO (AsyncIO a)) -> (IO (AsyncIO a) -> IO (AsyncIO b)) -> AsyncIO b
 mapPlumbing plumbing fn = AsyncIOPlumbing $ \maskingState cancellationToken -> fn (plumbing maskingState cancellationToken)
 
-bindAsync :: forall a b. Async a -> (a -> AsyncIO b) -> AsyncIO b
+bindAsync :: forall a b. Awaitable a -> (a -> AsyncIO b) -> AsyncIO b
 bindAsync x fn = bindAsyncCatch x (either AsyncIOFailure fn)
 
-bindAsyncCatch :: forall a b. Async a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b
-bindAsyncCatch x fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do
-  var <- newAsyncVar
-  disposableMVar <- newEmptyMVar
-  go maskingState cancellationToken var disposableMVar
-  where
-    go maskingState cancellationToken var disposableMVar = do
-      disposable <- onResult x (failAsyncVar_ var) $ \x -> do
-        (putAsyncIOResult . fn) x
-      -- TODO update mvar and dispose when completed
-      putMVar disposableMVar disposable
-      pure $ awaitUnlessCancellationRequested cancellationToken var
-      where
-        put = putAsyncVarEither var
-        putAsyncIOResult :: AsyncIO b -> IO ()
-        putAsyncIOResult (AsyncIOSuccess x) = put (Right x)
-        putAsyncIOResult (AsyncIOFailure x) = put (Left x)
-        putAsyncIOResult (AsyncIOIO x) = try x >>= put
-        putAsyncIOResult (AsyncIOAsync x) = onResult_ x (put . Left) put
-        putAsyncIOResult (AsyncIOPlumbing x) = x maskingState cancellationToken >>= putAsyncIOResult
-
-
-
--- | Run the synchronous part of an `AsyncIO` and then return an `Async` that can be used to wait for completion of the synchronous part.
-async :: AsyncIO r -> AsyncIO (Async r)
-async (AsyncIOSuccess x) = pure $ successfulAsync x
-async (AsyncIOFailure x) = pure $ failedAsync x
-async (AsyncIOIO x) = liftIO $ either failedAsync successfulAsync <$> try x
+bindAsyncCatch :: forall a b. Awaitable a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b
+bindAsyncCatch x fn = undefined -- AsyncIOPlumbing $ \maskingState cancellationToken -> do
+  --var <- newAsyncVar
+  --disposableMVar <- newEmptyMVar
+  --go maskingState cancellationToken var disposableMVar
+  --where
+  --  go maskingState cancellationToken var disposableMVar = do
+  --    disposable <- onResult x (failAsyncVar_ var) $ \x -> do
+  --      (putAsyncIOResult . fn) x
+  --    -- TODO update mvar and dispose when completed
+  --    putMVar disposableMVar disposable
+  --    pure $ awaitUnlessCancellationRequested cancellationToken var
+  --    where
+  --      put = putAsyncVarEither var
+  --      putAsyncIOResult :: AsyncIO b -> IO ()
+  --      putAsyncIOResult (AsyncIOSuccess x) = put (Right x)
+  --      putAsyncIOResult (AsyncIOFailure x) = put (Left x)
+  --      putAsyncIOResult (AsyncIOIO x) = try x >>= put
+  --      putAsyncIOResult (AsyncIOAsync x) = onResult_ x (put . Left) put
+  --      putAsyncIOResult (AsyncIOPlumbing x) = x maskingState cancellationToken >>= putAsyncIOResult
+
+
+
+-- | Run the synchronous part of an `AsyncIO` and then return an `Awaitable` that can be used to wait for completion of the synchronous part.
+async :: AsyncIO r -> AsyncIO (Awaitable r)
+async (AsyncIOSuccess x) = pure $ successfulAwaitable x
+async (AsyncIOFailure x) = pure $ failedAwaitable x
+async (AsyncIOIO x) = liftIO $ either failedAwaitable successfulAwaitable <$> try x
 async (AsyncIOAsync x) = pure x -- TODO caching
 async (AsyncIOPlumbing x) = mapPlumbing x (fmap async)
 
-await :: IsAsync r a => a -> AsyncIO r
-await = AsyncIOAsync . toAsync
+await :: IsAwaitable r a => a -> AsyncIO r
+await = AsyncIOAsync . toAwaitable
 
 -- | Run an `AsyncIO` to completion and return the result.
 runAsyncIO :: AsyncIO r -> IO r
 runAsyncIO (AsyncIOSuccess x) = pure x
 runAsyncIO (AsyncIOFailure x) = throwIO x
 runAsyncIO (AsyncIOIO x) = x
-runAsyncIO (AsyncIOAsync x) = wait x
+runAsyncIO (AsyncIOAsync x) = either throwIO pure =<< atomically (awaitSTM x)
 runAsyncIO (AsyncIOPlumbing x) = do
   maskingState <- getMaskingState
   withCancellationToken $ x maskingState >=> runAsyncIO
 
-awaitResult :: AsyncIO (Async r) -> AsyncIO r
+awaitResult :: AsyncIO (Awaitable r) -> AsyncIO r
 awaitResult = (await =<<)
 
 
-mapAsync :: (a -> b) -> Async a -> AsyncIO (Async b)
--- FIXME: don't actually attach a function if the resulting async is not used
--- maybe use `Weak`? When `Async b` is GC'ed, the handler is detached from `Async a`
-mapAsync fn = async . fmap fn . await
-
 
 -- ** Forking asyncs
 
@@ -239,77 +216,24 @@ mapAsync fn = async . fmap fn . await
 
 -- ** AsyncVar
 
--- | The default implementation for a `Async` that can be fulfilled later.
-newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r))
-data AsyncVarState r
-  = AsyncVarCompleted (Either SomeException r) (IO ())
-  | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ()))
-
-instance IsAsync r (AsyncVar r) where
-  wait x = do
-    mvar <- newEmptyMVar
-    onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar)
-    readMVar mvar >>= either throwIO pure
-    where
-      resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
-      resultCallback mvar result = do
-        success <- tryPutMVar mvar result
-        unless success $ fail "Callback was called multiple times"
-
-  peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r))
-  peekAsync (AsyncVar mvar) = readMVar mvar >>= pure . \case
-    AsyncVarCompleted x _ -> Just x
-    AsyncVarOpen _ -> Nothing
-
-  onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO CallbackDisposable
-  onResult (AsyncVar mvar) callbackExceptionHandler callback =
-    modifyMVar mvar $ \case
-      AsyncVarOpen callbacks -> do
-        key <- newUnique
-        pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), callbackDisposable key)
-      x@(AsyncVarCompleted value _) -> (x, noCallbackDisposable) <$ callback value `catch` callbackExceptionHandler
-    where
-      callbackDisposable :: Unique -> CallbackDisposable
-      callbackDisposable key = CallbackDisposable removeHandler removeHandlerEventually
-        where
-          removeHandler = do
-            waitForCallbacks <- modifyMVar mvar $ pure . \case
-              x@(AsyncVarCompleted _ waitForCallbacks) -> (x, waitForCallbacks)
-              AsyncVarOpen x -> (AsyncVarOpen (HM.delete key x), pure ())
-            -- Dispose should only return after the callback can't be called any longer
-            -- If the callbacks are already being dispatched, wait for them to complete to keep the guarantee
-            waitForCallbacks
-
-          removeHandlerEventually =
-            modifyMVar_ mvar $ pure . \case
-              x@(AsyncVarCompleted _ _) -> x
-              AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x
-
-  onResult_ x y = void . onResult x y
+-- | The default implementation for an `Awaitable` that can be fulfilled later.
+newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r))
+
+instance IsAwaitable r (AsyncVar r) where
+  peekSTM (AsyncVar var) = tryReadTMVar var
+
+tryPutAsyncVarEitherSTM :: AsyncVar a -> Either SomeException a -> STM Bool
+tryPutAsyncVarEitherSTM (AsyncVar var) = tryPutTMVar var
 
 tryPutAsyncVarEither :: forall a m. MonadIO m => AsyncVar a -> Either SomeException a -> m Bool
-tryPutAsyncVarEither (AsyncVar mvar) value = liftIO $ do
-  action <- modifyMVar mvar $ \case
-    x@(AsyncVarCompleted _ waitForCallbacks) -> pure (x, False <$ waitForCallbacks)
-    AsyncVarOpen callbacksMap -> do
-      callbacksCompletedMVar <- newEmptyMVar
-      let waitForCallbacks = readMVar callbacksCompletedMVar
-          callbacks = HM.elems callbacksMap
-      pure (AsyncVarCompleted value waitForCallbacks, fireCallbacks callbacks callbacksCompletedMVar)
-
-  action
-
-  where
-    fireCallbacks :: [(Either SomeException a -> IO (), SomeException -> IO ())] -> MVar () -> IO Bool
-    fireCallbacks callbacks callbacksCompletedMVar = do
-      forM_ callbacks $ \(callback, callbackExceptionHandler) ->
-        callback value `catch` callbackExceptionHandler
-      putMVar callbacksCompletedMVar ()
-      pure True
+tryPutAsyncVarEither var = liftIO . atomically . tryPutAsyncVarEitherSTM var
 
 
+newAsyncVarSTM :: STM (AsyncVar r)
+newAsyncVarSTM = AsyncVar <$> newEmptyTMVar
+
 newAsyncVar :: MonadIO m => m (AsyncVar r)
-newAsyncVar = liftIO $ AsyncVar <$> newMVar (AsyncVarOpen HM.empty)
+newAsyncVar = liftIO $ AsyncVar <$> newEmptyTMVarIO
 
 
 putAsyncVar :: MonadIO m => AsyncVar a -> a -> m ()
@@ -338,25 +262,29 @@ tryPutAsyncVarEither_ var = void . tryPutAsyncVarEither var
 
 -- * Awaiting multiple asyncs
 
-awaitEither :: (IsAsync ra a , IsAsync rb b) => a -> b -> AsyncIO (Either ra rb)
+awaitEither :: (IsAwaitable ra a , IsAwaitable rb b) => a -> b -> AsyncIO (Either ra rb)
 awaitEither x y = AsyncIOPlumbing $ \_ _ -> AsyncIOAsync <$> awaitEitherPlumbing x y
 
-awaitEitherPlumbing :: (IsAsync ra a , IsAsync rb b) => a -> b -> IO (Async (Either ra rb))
-awaitEitherPlumbing x y = do
-  var <- newAsyncVar
-  d1 <- onResult x (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Left)
-  d2 <- onResult y (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Right)
-  -- The resulting async is kept in memory by 'x' or 'y' until one of them completes.
-  onResult_ var (const (pure ())) (const (disposeCallbackEventually d1 *> disposeCallbackEventually d2))
-  pure $ toAsync var
+awaitEitherPlumbing :: (IsAwaitable ra a , IsAwaitable rb b) => a -> b -> IO (Awaitable (Either ra rb))
+awaitEitherPlumbing x y = awaitableSTM $ peekEitherSTM x y
+
+peekEitherSTM :: (IsAwaitable ra a , IsAwaitable rb b) => a -> b -> STM (Maybe (Either SomeException (Either ra rb)))
+peekEitherSTM x y =
+  peekSTM x >>= \case
+    Just (Left ex) -> pure (Just (Left ex))
+    Just (Right r) -> pure (Just (Right (Left r)))
+    Nothing -> peekSTM y >>= \case
+      Just (Left ex) -> pure (Just (Left ex))
+      Just (Right r) -> pure (Just (Right (Right r)))
+      Nothing -> pure Nothing
 
 
 -- * Cancellation
 
 newtype CancellationToken = CancellationToken (AsyncVar Void)
 
-instance IsAsync Void CancellationToken where
-  toAsync (CancellationToken var) = toAsync var
+instance IsAwaitable Void CancellationToken where
+  toAwaitable (CancellationToken var) = toAwaitable var
 
 newCancellationToken :: IO CancellationToken
 newCancellationToken = CancellationToken <$> newAsyncVar
@@ -365,19 +293,19 @@ cancel :: Exception e => CancellationToken -> e -> IO ()
 cancel (CancellationToken var) = failAsyncVar_ var . toException
 
 isCancellationRequested :: CancellationToken -> IO Bool
-isCancellationRequested (CancellationToken var) = isJust <$> peekAsync var
+isCancellationRequested (CancellationToken var) = isJust <$> peekAwaitable var
 
 cancellationState :: CancellationToken -> IO (Maybe SomeException)
-cancellationState (CancellationToken var) = (either Just (const Nothing) =<<) <$> peekAsync var
+cancellationState (CancellationToken var) = (either Just (const Nothing) =<<) <$> peekAwaitable var
 
 throwIfCancellationRequested :: CancellationToken -> IO ()
 throwIfCancellationRequested (CancellationToken var) =
-  peekAsync var >>= \case
+  peekAwaitable var >>= \case
     Just (Left ex) -> throwIO ex
     _ -> pure ()
 
-awaitUnlessCancellationRequested :: IsAsync a b => CancellationToken -> b -> AsyncIO a
-awaitUnlessCancellationRequested cancellationToken = fmap (either absurd id) . awaitEither cancellationToken . toAsync
+awaitUnlessCancellationRequested :: IsAwaitable a b => CancellationToken -> b -> AsyncIO a
+awaitUnlessCancellationRequested cancellationToken = fmap (either absurd id) . awaitEither cancellationToken . toAwaitable
 
 
 withCancellationToken :: (CancellationToken -> IO a) -> IO a
@@ -435,21 +363,3 @@ synchronousDisposable = mkDisposable . liftIO
 
 noDisposable :: Disposable
 noDisposable = mempty
-
-
-
-data CallbackDisposable = CallbackDisposable (IO ()) (IO ())
-
-instance IsDisposable CallbackDisposable where
-  dispose = liftIO . disposeCallback
-  disposeIO = disposeCallback
-  toDisposable = Disposable . dispose
-
-disposeCallback :: CallbackDisposable -> IO ()
-disposeCallback (CallbackDisposable f _) = f
-
-disposeCallbackEventually :: CallbackDisposable -> IO ()
-disposeCallbackEventually (CallbackDisposable _ e) = e
-
-noCallbackDisposable :: CallbackDisposable
-noCallbackDisposable = CallbackDisposable mempty mempty
diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs
index 55c4684..79b2b0b 100644
--- a/test/Quasar/AsyncSpec.hs
+++ b/test/Quasar/AsyncSpec.hs
@@ -1,10 +1,9 @@
 module Quasar.AsyncSpec (spec) where
 
 import Control.Concurrent
-import Control.Exception (throwIO)
-import Control.Monad (void)
+import Control.Concurrent.STM
+import Control.Monad (void, (<=<))
 import Control.Monad.IO.Class
-import Data.Either (isRight)
 import Prelude
 import Test.Hspec
 import Quasar.Core
@@ -24,17 +23,6 @@ spec = parallel $ do
       avar <- newAsyncVar :: IO (AsyncVar ())
       putAsyncVar avar ()
 
-    it "calls a callback" $ do
-      avar <- newAsyncVar :: IO (AsyncVar ())
-
-      mvar <- newEmptyMVar
-      onResult_ avar throwIO (putMVar mvar)
-
-      (() <$) <$> tryTakeMVar mvar `shouldReturn` Nothing
-
-      putAsyncVar avar ()
-      tryTakeMVar mvar `shouldSatisfyM` maybe False isRight
-
   describe "AsyncIO" $ do
     it "binds pure operations" $ do
       runAsyncIO (pure () >>= \() -> pure ())
@@ -82,7 +70,10 @@ spec = parallel $ do
       result `shouldBe` Nothing
 
   describe "CancellationToken" $ do
-    it "can be waited upon" $ do
-      result <- timeout 100000 $ withCancellationToken wait
-      result `shouldBe` Nothing -- `wait` re-throws the exception
+    it "propagates outer exceptions to the cancellation token" $ do
+      result <- timeout 100000 $ withCancellationToken (runAsyncIO . await)
+      result `shouldBe` Nothing
 
+    it "can return a value after cancellation" $ do
+      result <- timeout 100000 $ withCancellationToken (fmap (either (const True) (const False)) . atomically . awaitSTM)
+      result `shouldBe` Just True
-- 
GitLab