diff --git a/src/Quasar/Core.hs b/src/Quasar/Core.hs index a8375166ef2c4451bb33dea0a5593ff620e87161..17826fa4453562510bc39088e30d9c42ad4d434b 100644 --- a/src/Quasar/Core.hs +++ b/src/Quasar/Core.hs @@ -30,8 +30,11 @@ module Quasar.Core ( noDisposable, ) where +import Control.Concurrent (forkIOWithUnmask) +import Control.Exception (MaskingState(..), getMaskingState) import Control.Monad.Catch import Data.HashMap.Strict qualified as HM +import Data.Maybe (isJust) import Quasar.Prelude -- * Async @@ -111,16 +114,13 @@ data AsyncIO r | AsyncIOPlumbing (IO (AsyncIO r)) instance Functor AsyncIO where - fmap fn (AsyncIOSuccess x) = AsyncIOSuccess (fn x) - fmap _ (AsyncIOFailure x) = AsyncIOFailure x - fmap fn (AsyncIOIO x) = AsyncIOIO (fn <$> x) - fmap fn (AsyncIOAsync x) = bindAsync x (pure . fn) - fmap fn (AsyncIOPlumbing x) = AsyncIOPlumbing (fn <<$>> x) + fmap fn = (>>= pure . fn) instance Applicative AsyncIO where pure = AsyncIOSuccess (<*>) pf px = pf >>= \f -> f <$> px liftA2 f px py = px >>= \x -> f x <$> py + instance Monad AsyncIO where (>>=) :: forall a b. AsyncIO a -> (a -> AsyncIO b) -> AsyncIO b (>>=) (AsyncIOSuccess x) fn = fn x @@ -141,14 +141,14 @@ instance MonadCatch AsyncIO where catch x@(AsyncIOFailure ex) handler = maybe x handler (fromException ex) catch (AsyncIOIO x) handler = AsyncIOIO (try x) >>= handleEither handler catch (AsyncIOAsync x) handler = bindAsyncCatch x (handleEither handler) - catch (AsyncIOPlumbing x) handler = AsyncIOPlumbing $ x >>= pure . (`catch` handler) + catch (AsyncIOPlumbing x) handler = AsyncIOPlumbing $ (`catch` handler) <$> x handleEither :: Exception e => (e -> AsyncIO a) -> Either SomeException a -> AsyncIO a handleEither handler (Left ex) = maybe (AsyncIOFailure ex) handler (fromException ex) handleEither _ (Right r) = pure r bindAsync :: forall a b. Async a -> (a -> AsyncIO b) -> AsyncIO b -bindAsync x fn = bindAsyncCatch x (either (AsyncIOFailure) fn) +bindAsync x fn = bindAsyncCatch x (either AsyncIOFailure fn) bindAsyncCatch :: forall a b. Async a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b bindAsyncCatch x fn = AsyncIOPlumbing $ newAsyncVar >>= bindAsync' @@ -184,7 +184,8 @@ runAsyncIO (AsyncIOSuccess x) = pure x runAsyncIO (AsyncIOFailure x) = throwIO x runAsyncIO (AsyncIOIO x) = x runAsyncIO (AsyncIOAsync x) = wait x -runAsyncIO (AsyncIOPlumbing x) = x >>= runAsyncIO -- TODO error handling +runAsyncIO (AsyncIOPlumbing x) = do + x >>= runAsyncIO awaitResult :: AsyncIO (Async r) -> AsyncIO r awaitResult = (await =<<) @@ -209,7 +210,9 @@ mapAsync fn = async . fmap fn . await -- | The default implementation for a `Async` that can be fulfilled later. newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r)) -data AsyncVarState r = AsyncVarCompleted (Either SomeException r) | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ())) +data AsyncVarState r + = AsyncVarCompleted (Either SomeException r) + | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ())) instance IsAsync r (AsyncVar r) where peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r)) @@ -255,6 +258,40 @@ putAsyncVarEither (AsyncVar mvar) value = liftIO $ do putMVar mvar (AsyncVarCompleted value) +-- * Cancellation + +newtype CancellationToken = CancellationToken (AsyncVar Void) + +newCancellationToken :: IO CancellationToken +newCancellationToken = CancellationToken <$> newAsyncVar + +cancel :: Exception e => CancellationToken -> e -> IO () +cancel (CancellationToken var) = failAsyncVar var . toException + +isCancellationRequested :: CancellationToken -> IO Bool +isCancellationRequested (CancellationToken var) = isJust <$> peekAsync var + +cancellationState :: CancellationToken -> IO (Maybe SomeException) +cancellationState (CancellationToken var) = (either Just (const Nothing) =<<) <$> peekAsync var + +throwIfCancellationRequested :: CancellationToken -> IO () +throwIfCancellationRequested (CancellationToken var) = + peekAsync var >>= \case + Just (Left ex) -> throwIO ex + _ -> pure () + +withCancellationToken :: (CancellationToken -> IO a) -> IO a +withCancellationToken action = do + cancellationToken <- newCancellationToken + resultMVar :: MVar (Either SomeException a) <- newEmptyMVar + + uninterruptibleMask $ \unmask -> do + void $ forkIOWithUnmask $ \threadUnmask -> do + putMVar resultMVar =<< try (threadUnmask (action cancellationToken)) + + either throwIO pure =<< (unmask (takeMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> takeMVar resultMVar)) + + -- * Disposable class IsDisposable a where diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs index 15bad2b20fc5ffdc24826c6d8ca565e10b5693a2..ace2f4990d5160f04c87184c0e7dbe809615187b 100644 --- a/test/Quasar/AsyncSpec.hs +++ b/test/Quasar/AsyncSpec.hs @@ -8,6 +8,7 @@ import Data.Either (isRight) import Prelude import Test.Hspec import Quasar.Core +import System.Timeout shouldSatisfyM :: (HasCallStack, Show a) => IO a -> (a -> Bool) -> Expectation shouldSatisfyM action expected = action >>= (`shouldSatisfy` expected) @@ -71,3 +72,11 @@ spec = parallel $ do threadDelay 100000 putAsyncVar avar () runAsyncIO (await avar >>= pure) + + it "can terminate when encountering an asynchronous exception" $ do + never <- newAsyncVar :: IO (AsyncVar ()) + + result <- timeout 100000 $ runAsyncIO $ + -- Use bind to create an AsyncIOPlumbing, which is the interesting case that uses `uninterruptibleMask` when run + await never >>= pure + result `shouldBe` Nothing