diff --git a/src/Quasar/Core.hs b/src/Quasar/Core.hs index 17826fa4453562510bc39088e30d9c42ad4d434b..81eed95899d207594d571b3e06b817b893c687f6 100644 --- a/src/Quasar/Core.hs +++ b/src/Quasar/Core.hs @@ -23,11 +23,13 @@ module Quasar.Core ( -- * Disposable IsDisposable(..), - disposeIO, Disposable, mkDisposable, synchronousDisposable, noDisposable, + + -- * Cancellation + withCancellationToken, ) where import Control.Concurrent (forkIOWithUnmask) @@ -35,6 +37,7 @@ import Control.Exception (MaskingState(..), getMaskingState) import Control.Monad.Catch import Data.HashMap.Strict qualified as HM import Data.Maybe (isJust) +import Data.Void (absurd) import Quasar.Prelude -- * Async @@ -42,22 +45,17 @@ import Quasar.Prelude class IsAsync r a | a -> r where -- | Wait until the promise is settled and return the result. wait :: a -> IO r - wait x = do - mvar <- newEmptyMVar - onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar) - readMVar mvar >>= either throwIO pure - where - resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO () - resultCallback mvar result = do - success <- tryPutMVar mvar result - unless success $ fail "Callback was called multiple times" + wait = wait . toAsync peekAsync :: a -> IO (Maybe (Either SomeException r)) + peekAsync = peekAsync . toAsync -- | Register a callback, that will be called once the promise is settled. -- If the promise is already settled, the callback will be called immediately instead. -- -- The returned `Disposable` can be used to deregister the callback. + -- + -- 'onResult' should not throw. onResult :: a -- ^ async @@ -65,18 +63,21 @@ class IsAsync r a | a -> r where -- ^ callback exception handler -> (Either SomeException r -> IO ()) -- ^ callback - -> IO Disposable + -> IO CallbackDisposable + onResult x ceh c = onResult (toAsync x) ceh c onResult_ :: a -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO () - onResult_ x y = void . onResult x y + onResult_ x ceh c = onResult_ (toAsync x) ceh c toAsync :: a -> Async r toAsync = SomeAsync + {-# MINIMAL toAsync | (wait, peekAsync, onResult, onResult_) #-} + data Async r = forall a. IsAsync r a => SomeAsync a @@ -86,12 +87,18 @@ instance IsAsync r (Async r) where onResult_ (SomeAsync x) y = onResult_ x y peekAsync (SomeAsync x) = peekAsync x +instance Functor Async where + fmap fn = toAsync . MappedAsync fn + newtype CompletedAsync r = CompletedAsync (Either SomeException r) instance IsAsync r (CompletedAsync r) where wait (CompletedAsync value) = either throwIO pure value - onResult (CompletedAsync value) callbackExceptionHandler callback = noDisposable <$ (callback value `catch` callbackExceptionHandler) + onResult (CompletedAsync value) callbackExceptionHandler callback = + noCallbackDisposable <$ (callback value `catch` callbackExceptionHandler) + onResult_ (CompletedAsync value) callbackExceptionHandler callback = + callback value `catch` callbackExceptionHandler peekAsync (CompletedAsync value) = pure $ Just value completedAsync :: Either SomeException r -> Async r @@ -104,6 +111,14 @@ failedAsync :: SomeException -> Async r failedAsync = completedAsync . Left +data MappedAsync r = forall a. MappedAsync (a -> r) (Async a) +instance IsAsync r (MappedAsync r) where + wait (MappedAsync fn x) = fn <$> wait x + peekAsync (MappedAsync fn x) = fmap fn <<$>> peekAsync x + onResult (MappedAsync fn x) callbackExceptionHandler callback = onResult x callbackExceptionHandler $ callback . fmap fn + onResult_ (MappedAsync fn x) callbackExceptionHandler callback = onResult_ x callbackExceptionHandler $ callback . fmap fn + + -- * AsyncIO data AsyncIO r @@ -111,10 +126,14 @@ data AsyncIO r | AsyncIOFailure SomeException | AsyncIOIO (IO r) | AsyncIOAsync (Async r) - | AsyncIOPlumbing (IO (AsyncIO r)) + | AsyncIOPlumbing (MaskingState -> CancellationToken -> IO (AsyncIO r)) instance Functor AsyncIO where - fmap fn = (>>= pure . fn) + fmap fn (AsyncIOSuccess x) = AsyncIOSuccess (fn x) + fmap _ (AsyncIOFailure x) = AsyncIOFailure x + fmap fn (AsyncIOIO x) = AsyncIOIO (fn <$> x) + fmap fn (AsyncIOAsync x) = AsyncIOAsync (fn <$> x) + fmap fn (AsyncIOPlumbing x) = mapPlumbing x (fmap (fmap fn)) instance Applicative AsyncIO where pure = AsyncIOSuccess @@ -125,9 +144,11 @@ instance Monad AsyncIO where (>>=) :: forall a b. AsyncIO a -> (a -> AsyncIO b) -> AsyncIO b (>>=) (AsyncIOSuccess x) fn = fn x (>>=) (AsyncIOFailure x) _ = AsyncIOFailure x - (>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ either AsyncIOFailure fn <$> try x + (>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do + -- TODO masking and cancellation + either AsyncIOFailure fn <$> try x (>>=) (AsyncIOAsync x) fn = bindAsync x fn - (>>=) (AsyncIOPlumbing x) fn = AsyncIOPlumbing $ (>>= fn) <$> x + (>>=) (AsyncIOPlumbing x) fn = mapPlumbing x (fmap (>>= fn)) instance MonadIO AsyncIO where liftIO = AsyncIOIO @@ -141,39 +162,48 @@ instance MonadCatch AsyncIO where catch x@(AsyncIOFailure ex) handler = maybe x handler (fromException ex) catch (AsyncIOIO x) handler = AsyncIOIO (try x) >>= handleEither handler catch (AsyncIOAsync x) handler = bindAsyncCatch x (handleEither handler) - catch (AsyncIOPlumbing x) handler = AsyncIOPlumbing $ (`catch` handler) <$> x + catch (AsyncIOPlumbing x) handler = mapPlumbing x (fmap (`catch` handler)) handleEither :: Exception e => (e -> AsyncIO a) -> Either SomeException a -> AsyncIO a handleEither handler (Left ex) = maybe (AsyncIOFailure ex) handler (fromException ex) handleEither _ (Right r) = pure r +mapPlumbing :: (MaskingState -> CancellationToken -> IO (AsyncIO a)) -> (IO (AsyncIO a) -> IO (AsyncIO b)) -> AsyncIO b +mapPlumbing plumbing fn = AsyncIOPlumbing $ \maskingState cancellationToken -> fn (plumbing maskingState cancellationToken) + bindAsync :: forall a b. Async a -> (a -> AsyncIO b) -> AsyncIO b bindAsync x fn = bindAsyncCatch x (either AsyncIOFailure fn) bindAsyncCatch :: forall a b. Async a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b -bindAsyncCatch x fn = AsyncIOPlumbing $ newAsyncVar >>= bindAsync' +bindAsyncCatch x fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do + var <- newAsyncVar + disposableMVar <- newEmptyMVar + go maskingState cancellationToken var disposableMVar where - bindAsync' resultVar = do - withResult x resultVar step - pure $ await resultVar - step :: (Either SomeException b -> IO ()) -> Either SomeException a -> IO () - step put = putAsyncIOResult put . fn - -withResult :: Async a -> AsyncVar b -> ((Either SomeException b -> IO ()) -> Either SomeException a -> IO ()) -> IO () -withResult x var fn = onResult_ x (failAsyncVar var) (fn (putAsyncVarEither var)) - -putAsyncIOResult :: (Either SomeException a -> IO ()) -> AsyncIO a -> IO () -putAsyncIOResult put (AsyncIOSuccess x) = put (Right x) -putAsyncIOResult put (AsyncIOFailure x) = put (Left x) -putAsyncIOResult put (AsyncIOIO x) = try x >>= put -putAsyncIOResult put (AsyncIOAsync x) = onResult_ x (put . Left) put -putAsyncIOResult put (AsyncIOPlumbing x) = x >>= putAsyncIOResult put + go maskingState cancellationToken var disposableMVar = do + disposable <- onResult x (failAsyncVar_ var) $ \x -> do + (putAsyncIOResult . fn) x + -- TODO update mvar and dispose when completed + putMVar disposableMVar disposable + pure $ awaitUnlessCancellationRequested cancellationToken var + where + put = putAsyncVarEither var + putAsyncIOResult :: AsyncIO b -> IO () + putAsyncIOResult (AsyncIOSuccess x) = put (Right x) + putAsyncIOResult (AsyncIOFailure x) = put (Left x) + putAsyncIOResult (AsyncIOIO x) = try x >>= put + putAsyncIOResult (AsyncIOAsync x) = onResult_ x (put . Left) put + putAsyncIOResult (AsyncIOPlumbing x) = x maskingState cancellationToken >>= putAsyncIOResult -- | Run the synchronous part of an `AsyncIO` and then return an `Async` that can be used to wait for completion of the synchronous part. async :: AsyncIO r -> AsyncIO (Async r) -async = fmap successfulAsync +async (AsyncIOSuccess x) = pure $ successfulAsync x +async (AsyncIOFailure x) = pure $ failedAsync x +async (AsyncIOIO x) = liftIO $ either failedAsync successfulAsync <$> try x +async (AsyncIOAsync x) = pure x -- TODO caching +async (AsyncIOPlumbing x) = mapPlumbing x (fmap async) await :: IsAsync r a => a -> AsyncIO r await = AsyncIOAsync . toAsync @@ -185,7 +215,8 @@ runAsyncIO (AsyncIOFailure x) = throwIO x runAsyncIO (AsyncIOIO x) = x runAsyncIO (AsyncIOAsync x) = wait x runAsyncIO (AsyncIOPlumbing x) = do - x >>= runAsyncIO + maskingState <- getMaskingState + withCancellationToken $ x maskingState >=> runAsyncIO awaitResult :: AsyncIO (Async r) -> AsyncIO r awaitResult = (await =<<) @@ -211,62 +242,127 @@ mapAsync fn = async . fmap fn . await -- | The default implementation for a `Async` that can be fulfilled later. newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r)) data AsyncVarState r - = AsyncVarCompleted (Either SomeException r) + = AsyncVarCompleted (Either SomeException r) (IO ()) | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ())) instance IsAsync r (AsyncVar r) where + wait x = do + mvar <- newEmptyMVar + onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar) + readMVar mvar >>= either throwIO pure + where + resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO () + resultCallback mvar result = do + success <- tryPutMVar mvar result + unless success $ fail "Callback was called multiple times" + peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r)) peekAsync (AsyncVar mvar) = readMVar mvar >>= pure . \case - AsyncVarCompleted x -> Just x + AsyncVarCompleted x _ -> Just x AsyncVarOpen _ -> Nothing - onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO Disposable + onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO CallbackDisposable onResult (AsyncVar mvar) callbackExceptionHandler callback = modifyMVar mvar $ \case AsyncVarOpen callbacks -> do key <- newUnique - pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), removeHandler key) - x@(AsyncVarCompleted value) -> (x, noDisposable) <$ callback value + pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), callbackDisposable key) + x@(AsyncVarCompleted value _) -> (x, noCallbackDisposable) <$ callback value `catch` callbackExceptionHandler where - removeHandler :: Unique -> Disposable - removeHandler key = synchronousDisposable $ modifyMVar_ mvar $ pure . \case - x@(AsyncVarCompleted _) -> x - AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x + callbackDisposable :: Unique -> CallbackDisposable + callbackDisposable key = CallbackDisposable removeHandler removeHandlerEventually + where + removeHandler = do + waitForCallbacks <- modifyMVar mvar $ pure . \case + x@(AsyncVarCompleted _ waitForCallbacks) -> (x, waitForCallbacks) + AsyncVarOpen x -> (AsyncVarOpen (HM.delete key x), pure ()) + -- Dispose should only return after the callback can't be called any longer + -- If the callbacks are already being dispatched, wait for them to complete to keep the guarantee + waitForCallbacks + + removeHandlerEventually = + modifyMVar_ mvar $ pure . \case + x@(AsyncVarCompleted _ _) -> x + AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x + + onResult_ x y = void . onResult x y + +tryPutAsyncVarEither :: forall a m. MonadIO m => AsyncVar a -> Either SomeException a -> m Bool +tryPutAsyncVarEither (AsyncVar mvar) value = liftIO $ do + action <- modifyMVar mvar $ \case + x@(AsyncVarCompleted _ waitForCallbacks) -> pure (x, False <$ waitForCallbacks) + AsyncVarOpen callbacksMap -> do + callbacksCompletedMVar <- newEmptyMVar + let waitForCallbacks = readMVar callbacksCompletedMVar + callbacks = HM.elems callbacksMap + pure (AsyncVarCompleted value waitForCallbacks, fireCallbacks callbacks callbacksCompletedMVar) + + action + + where + fireCallbacks :: [(Either SomeException a -> IO (), SomeException -> IO ())] -> MVar () -> IO Bool + fireCallbacks callbacks callbacksCompletedMVar = do + forM_ callbacks $ \(callback, callbackExceptionHandler) -> + callback value `catch` callbackExceptionHandler + putMVar callbacksCompletedMVar () + pure True newAsyncVar :: MonadIO m => m (AsyncVar r) newAsyncVar = liftIO $ AsyncVar <$> newMVar (AsyncVarOpen HM.empty) + putAsyncVar :: MonadIO m => AsyncVar a -> a -> m () putAsyncVar var = putAsyncVarEither var . Right -failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m () -failAsyncVar var = putAsyncVarEither var . Left +tryPutAsyncVar :: MonadIO m => AsyncVar a -> a -> m Bool +tryPutAsyncVar var = tryPutAsyncVarEither var . Right + +tryPutAsyncVar_ :: MonadIO m => AsyncVar a -> a -> m () +tryPutAsyncVar_ var = void . tryPutAsyncVar var + +failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m Bool +failAsyncVar var = tryPutAsyncVarEither var . Left + +failAsyncVar_ :: MonadIO m => AsyncVar a -> SomeException -> m () +failAsyncVar_ var = void . failAsyncVar var putAsyncVarEither :: MonadIO m => AsyncVar a -> Either SomeException a -> m () -putAsyncVarEither (AsyncVar mvar) value = liftIO $ do - mask $ \restore -> do - takeMVar mvar >>= \case - x@(AsyncVarCompleted _) -> do - putMVar mvar x - fail "An AsyncVar can only be fulfilled once" - AsyncVarOpen callbacksMap -> do - let callbacks = HM.elems callbacksMap - -- NOTE disposing a callback while it is called is a deadlock - forM_ callbacks $ \(callback, callbackExceptionHandler) -> - restore (callback value) `catch` callbackExceptionHandler - putMVar mvar (AsyncVarCompleted value) +putAsyncVarEither avar value = liftIO $ do + success <- tryPutAsyncVarEither avar value + unless success $ fail "An AsyncVar can only be fulfilled once" + +tryPutAsyncVarEither_ :: MonadIO m => AsyncVar a -> Either SomeException a -> m () +tryPutAsyncVarEither_ var = void . tryPutAsyncVarEither var + + +-- * Awaiting multiple asyncs + +awaitEither :: (IsAsync ra a , IsAsync rb b) => a -> b -> AsyncIO (Either ra rb) +awaitEither x y = AsyncIOPlumbing $ \_ _ -> AsyncIOAsync <$> awaitEitherPlumbing x y + +awaitEitherPlumbing :: (IsAsync ra a , IsAsync rb b) => a -> b -> IO (Async (Either ra rb)) +awaitEitherPlumbing x y = do + var <- newAsyncVar + d1 <- onResult x (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Left) + d2 <- onResult y (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Right) + -- The resulting async is kept in memory by 'x' or 'y' until one of them completes. + onResult_ var (const (pure ())) (const (disposeCallbackEventually d1 *> disposeCallbackEventually d2)) + pure $ toAsync var -- * Cancellation newtype CancellationToken = CancellationToken (AsyncVar Void) +instance IsAsync Void CancellationToken where + toAsync (CancellationToken var) = toAsync var + newCancellationToken :: IO CancellationToken newCancellationToken = CancellationToken <$> newAsyncVar cancel :: Exception e => CancellationToken -> e -> IO () -cancel (CancellationToken var) = failAsyncVar var . toException +cancel (CancellationToken var) = failAsyncVar_ var . toException isCancellationRequested :: CancellationToken -> IO Bool isCancellationRequested (CancellationToken var) = isJust <$> peekAsync var @@ -280,6 +376,10 @@ throwIfCancellationRequested (CancellationToken var) = Just (Left ex) -> throwIO ex _ -> pure () +awaitUnlessCancellationRequested :: IsAsync a b => CancellationToken -> b -> AsyncIO a +awaitUnlessCancellationRequested cancellationToken = fmap (either absurd id) . awaitEither cancellationToken . toAsync + + withCancellationToken :: (CancellationToken -> IO a) -> IO a withCancellationToken action = do cancellationToken <- newCancellationToken @@ -289,7 +389,8 @@ withCancellationToken action = do void $ forkIOWithUnmask $ \threadUnmask -> do putMVar resultMVar =<< try (threadUnmask (action cancellationToken)) - either throwIO pure =<< (unmask (takeMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> takeMVar resultMVar)) + -- TODO test if it is better to run readMVar recursively or to keep it uninterruptible + either throwIO pure =<< (unmask (readMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> readMVar resultMVar)) -- * Disposable @@ -300,21 +401,22 @@ class IsDisposable a where -- | Dispose a resource. dispose :: a -> AsyncIO () + -- | Dispose a resource in the IO monad. + disposeIO :: a -> IO () + toDisposable :: a -> Disposable toDisposable = mkDisposable . dispose --- | Dispose a resource in the IO monad. -disposeIO :: IsDisposable a => a -> IO () -disposeIO = runAsyncIO . dispose - instance IsDisposable a => IsDisposable (Maybe a) where dispose = mapM_ dispose + disposeIO = mapM_ disposeIO newtype Disposable = Disposable (AsyncIO ()) instance IsDisposable Disposable where dispose (Disposable fn) = fn + disposeIO = runAsyncIO . dispose toDisposable = id instance Semigroup Disposable where @@ -333,3 +435,21 @@ synchronousDisposable = mkDisposable . liftIO noDisposable :: Disposable noDisposable = mempty + + + +data CallbackDisposable = CallbackDisposable (IO ()) (IO ()) + +instance IsDisposable CallbackDisposable where + dispose = liftIO . disposeCallback + disposeIO = disposeCallback + toDisposable = Disposable . dispose + +disposeCallback :: CallbackDisposable -> IO () +disposeCallback (CallbackDisposable f _) = f + +disposeCallbackEventually :: CallbackDisposable -> IO () +disposeCallbackEventually (CallbackDisposable _ e) = e + +noCallbackDisposable :: CallbackDisposable +noCallbackDisposable = CallbackDisposable mempty mempty diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs index ace2f4990d5160f04c87184c0e7dbe809615187b..55c4684bd9807260425803029dd86c90e7d56cb8 100644 --- a/test/Quasar/AsyncSpec.hs +++ b/test/Quasar/AsyncSpec.hs @@ -80,3 +80,9 @@ spec = parallel $ do -- Use bind to create an AsyncIOPlumbing, which is the interesting case that uses `uninterruptibleMask` when run await never >>= pure result `shouldBe` Nothing + + describe "CancellationToken" $ do + it "can be waited upon" $ do + result <- timeout 100000 $ withCancellationToken wait + result `shouldBe` Nothing -- `wait` re-throws the exception +