Skip to content
Snippets Groups Projects
Commit 79e0acd3 authored by Jens Nolte's avatar Jens Nolte
Browse files

Implement AsyncIO cancellation

parent 5d86d239
No related branches found
No related tags found
No related merge requests found
......@@ -23,11 +23,13 @@ module Quasar.Core (
-- * Disposable
IsDisposable(..),
disposeIO,
Disposable,
mkDisposable,
synchronousDisposable,
noDisposable,
-- * Cancellation
withCancellationToken,
) where
import Control.Concurrent (forkIOWithUnmask)
......@@ -35,6 +37,7 @@ import Control.Exception (MaskingState(..), getMaskingState)
import Control.Monad.Catch
import Data.HashMap.Strict qualified as HM
import Data.Maybe (isJust)
import Data.Void (absurd)
import Quasar.Prelude
-- * Async
......@@ -42,22 +45,17 @@ import Quasar.Prelude
class IsAsync r a | a -> r where
-- | Wait until the promise is settled and return the result.
wait :: a -> IO r
wait x = do
mvar <- newEmptyMVar
onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar)
readMVar mvar >>= either throwIO pure
where
resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
resultCallback mvar result = do
success <- tryPutMVar mvar result
unless success $ fail "Callback was called multiple times"
wait = wait . toAsync
peekAsync :: a -> IO (Maybe (Either SomeException r))
peekAsync = peekAsync . toAsync
-- | Register a callback, that will be called once the promise is settled.
-- If the promise is already settled, the callback will be called immediately instead.
--
-- The returned `Disposable` can be used to deregister the callback.
--
-- 'onResult' should not throw.
onResult
:: a
-- ^ async
......@@ -65,18 +63,21 @@ class IsAsync r a | a -> r where
-- ^ callback exception handler
-> (Either SomeException r -> IO ())
-- ^ callback
-> IO Disposable
-> IO CallbackDisposable
onResult x ceh c = onResult (toAsync x) ceh c
onResult_
:: a
-> (SomeException -> IO ())
-> (Either SomeException r -> IO ())
-> IO ()
onResult_ x y = void . onResult x y
onResult_ x ceh c = onResult_ (toAsync x) ceh c
toAsync :: a -> Async r
toAsync = SomeAsync
{-# MINIMAL toAsync | (wait, peekAsync, onResult, onResult_) #-}
data Async r = forall a. IsAsync r a => SomeAsync a
......@@ -86,12 +87,18 @@ instance IsAsync r (Async r) where
onResult_ (SomeAsync x) y = onResult_ x y
peekAsync (SomeAsync x) = peekAsync x
instance Functor Async where
fmap fn = toAsync . MappedAsync fn
newtype CompletedAsync r = CompletedAsync (Either SomeException r)
instance IsAsync r (CompletedAsync r) where
wait (CompletedAsync value) = either throwIO pure value
onResult (CompletedAsync value) callbackExceptionHandler callback = noDisposable <$ (callback value `catch` callbackExceptionHandler)
onResult (CompletedAsync value) callbackExceptionHandler callback =
noCallbackDisposable <$ (callback value `catch` callbackExceptionHandler)
onResult_ (CompletedAsync value) callbackExceptionHandler callback =
callback value `catch` callbackExceptionHandler
peekAsync (CompletedAsync value) = pure $ Just value
completedAsync :: Either SomeException r -> Async r
......@@ -104,6 +111,14 @@ failedAsync :: SomeException -> Async r
failedAsync = completedAsync . Left
data MappedAsync r = forall a. MappedAsync (a -> r) (Async a)
instance IsAsync r (MappedAsync r) where
wait (MappedAsync fn x) = fn <$> wait x
peekAsync (MappedAsync fn x) = fmap fn <<$>> peekAsync x
onResult (MappedAsync fn x) callbackExceptionHandler callback = onResult x callbackExceptionHandler $ callback . fmap fn
onResult_ (MappedAsync fn x) callbackExceptionHandler callback = onResult_ x callbackExceptionHandler $ callback . fmap fn
-- * AsyncIO
data AsyncIO r
......@@ -111,10 +126,14 @@ data AsyncIO r
| AsyncIOFailure SomeException
| AsyncIOIO (IO r)
| AsyncIOAsync (Async r)
| AsyncIOPlumbing (IO (AsyncIO r))
| AsyncIOPlumbing (MaskingState -> CancellationToken -> IO (AsyncIO r))
instance Functor AsyncIO where
fmap fn = (>>= pure . fn)
fmap fn (AsyncIOSuccess x) = AsyncIOSuccess (fn x)
fmap _ (AsyncIOFailure x) = AsyncIOFailure x
fmap fn (AsyncIOIO x) = AsyncIOIO (fn <$> x)
fmap fn (AsyncIOAsync x) = AsyncIOAsync (fn <$> x)
fmap fn (AsyncIOPlumbing x) = mapPlumbing x (fmap (fmap fn))
instance Applicative AsyncIO where
pure = AsyncIOSuccess
......@@ -125,9 +144,11 @@ instance Monad AsyncIO where
(>>=) :: forall a b. AsyncIO a -> (a -> AsyncIO b) -> AsyncIO b
(>>=) (AsyncIOSuccess x) fn = fn x
(>>=) (AsyncIOFailure x) _ = AsyncIOFailure x
(>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ either AsyncIOFailure fn <$> try x
(>>=) (AsyncIOIO x) fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do
-- TODO masking and cancellation
either AsyncIOFailure fn <$> try x
(>>=) (AsyncIOAsync x) fn = bindAsync x fn
(>>=) (AsyncIOPlumbing x) fn = AsyncIOPlumbing $ (>>= fn) <$> x
(>>=) (AsyncIOPlumbing x) fn = mapPlumbing x (fmap (>>= fn))
instance MonadIO AsyncIO where
liftIO = AsyncIOIO
......@@ -141,39 +162,48 @@ instance MonadCatch AsyncIO where
catch x@(AsyncIOFailure ex) handler = maybe x handler (fromException ex)
catch (AsyncIOIO x) handler = AsyncIOIO (try x) >>= handleEither handler
catch (AsyncIOAsync x) handler = bindAsyncCatch x (handleEither handler)
catch (AsyncIOPlumbing x) handler = AsyncIOPlumbing $ (`catch` handler) <$> x
catch (AsyncIOPlumbing x) handler = mapPlumbing x (fmap (`catch` handler))
handleEither :: Exception e => (e -> AsyncIO a) -> Either SomeException a -> AsyncIO a
handleEither handler (Left ex) = maybe (AsyncIOFailure ex) handler (fromException ex)
handleEither _ (Right r) = pure r
mapPlumbing :: (MaskingState -> CancellationToken -> IO (AsyncIO a)) -> (IO (AsyncIO a) -> IO (AsyncIO b)) -> AsyncIO b
mapPlumbing plumbing fn = AsyncIOPlumbing $ \maskingState cancellationToken -> fn (plumbing maskingState cancellationToken)
bindAsync :: forall a b. Async a -> (a -> AsyncIO b) -> AsyncIO b
bindAsync x fn = bindAsyncCatch x (either AsyncIOFailure fn)
bindAsyncCatch :: forall a b. Async a -> (Either SomeException a -> AsyncIO b) -> AsyncIO b
bindAsyncCatch x fn = AsyncIOPlumbing $ newAsyncVar >>= bindAsync'
bindAsyncCatch x fn = AsyncIOPlumbing $ \maskingState cancellationToken -> do
var <- newAsyncVar
disposableMVar <- newEmptyMVar
go maskingState cancellationToken var disposableMVar
where
bindAsync' resultVar = do
withResult x resultVar step
pure $ await resultVar
step :: (Either SomeException b -> IO ()) -> Either SomeException a -> IO ()
step put = putAsyncIOResult put . fn
withResult :: Async a -> AsyncVar b -> ((Either SomeException b -> IO ()) -> Either SomeException a -> IO ()) -> IO ()
withResult x var fn = onResult_ x (failAsyncVar var) (fn (putAsyncVarEither var))
putAsyncIOResult :: (Either SomeException a -> IO ()) -> AsyncIO a -> IO ()
putAsyncIOResult put (AsyncIOSuccess x) = put (Right x)
putAsyncIOResult put (AsyncIOFailure x) = put (Left x)
putAsyncIOResult put (AsyncIOIO x) = try x >>= put
putAsyncIOResult put (AsyncIOAsync x) = onResult_ x (put . Left) put
putAsyncIOResult put (AsyncIOPlumbing x) = x >>= putAsyncIOResult put
go maskingState cancellationToken var disposableMVar = do
disposable <- onResult x (failAsyncVar_ var) $ \x -> do
(putAsyncIOResult . fn) x
-- TODO update mvar and dispose when completed
putMVar disposableMVar disposable
pure $ awaitUnlessCancellationRequested cancellationToken var
where
put = putAsyncVarEither var
putAsyncIOResult :: AsyncIO b -> IO ()
putAsyncIOResult (AsyncIOSuccess x) = put (Right x)
putAsyncIOResult (AsyncIOFailure x) = put (Left x)
putAsyncIOResult (AsyncIOIO x) = try x >>= put
putAsyncIOResult (AsyncIOAsync x) = onResult_ x (put . Left) put
putAsyncIOResult (AsyncIOPlumbing x) = x maskingState cancellationToken >>= putAsyncIOResult
-- | Run the synchronous part of an `AsyncIO` and then return an `Async` that can be used to wait for completion of the synchronous part.
async :: AsyncIO r -> AsyncIO (Async r)
async = fmap successfulAsync
async (AsyncIOSuccess x) = pure $ successfulAsync x
async (AsyncIOFailure x) = pure $ failedAsync x
async (AsyncIOIO x) = liftIO $ either failedAsync successfulAsync <$> try x
async (AsyncIOAsync x) = pure x -- TODO caching
async (AsyncIOPlumbing x) = mapPlumbing x (fmap async)
await :: IsAsync r a => a -> AsyncIO r
await = AsyncIOAsync . toAsync
......@@ -185,7 +215,8 @@ runAsyncIO (AsyncIOFailure x) = throwIO x
runAsyncIO (AsyncIOIO x) = x
runAsyncIO (AsyncIOAsync x) = wait x
runAsyncIO (AsyncIOPlumbing x) = do
x >>= runAsyncIO
maskingState <- getMaskingState
withCancellationToken $ x maskingState >=> runAsyncIO
awaitResult :: AsyncIO (Async r) -> AsyncIO r
awaitResult = (await =<<)
......@@ -211,62 +242,127 @@ mapAsync fn = async . fmap fn . await
-- | The default implementation for a `Async` that can be fulfilled later.
newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r))
data AsyncVarState r
= AsyncVarCompleted (Either SomeException r)
= AsyncVarCompleted (Either SomeException r) (IO ())
| AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ()))
instance IsAsync r (AsyncVar r) where
wait x = do
mvar <- newEmptyMVar
onResult_ x (void . tryPutMVar mvar . Left) (resultCallback mvar)
readMVar mvar >>= either throwIO pure
where
resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
resultCallback mvar result = do
success <- tryPutMVar mvar result
unless success $ fail "Callback was called multiple times"
peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r))
peekAsync (AsyncVar mvar) = readMVar mvar >>= pure . \case
AsyncVarCompleted x -> Just x
AsyncVarCompleted x _ -> Just x
AsyncVarOpen _ -> Nothing
onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO Disposable
onResult :: AsyncVar r -> (SomeException -> IO ()) -> (Either SomeException r -> IO ()) -> IO CallbackDisposable
onResult (AsyncVar mvar) callbackExceptionHandler callback =
modifyMVar mvar $ \case
AsyncVarOpen callbacks -> do
key <- newUnique
pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), removeHandler key)
x@(AsyncVarCompleted value) -> (x, noDisposable) <$ callback value
pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), callbackDisposable key)
x@(AsyncVarCompleted value _) -> (x, noCallbackDisposable) <$ callback value `catch` callbackExceptionHandler
where
removeHandler :: Unique -> Disposable
removeHandler key = synchronousDisposable $ modifyMVar_ mvar $ pure . \case
x@(AsyncVarCompleted _) -> x
AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x
callbackDisposable :: Unique -> CallbackDisposable
callbackDisposable key = CallbackDisposable removeHandler removeHandlerEventually
where
removeHandler = do
waitForCallbacks <- modifyMVar mvar $ pure . \case
x@(AsyncVarCompleted _ waitForCallbacks) -> (x, waitForCallbacks)
AsyncVarOpen x -> (AsyncVarOpen (HM.delete key x), pure ())
-- Dispose should only return after the callback can't be called any longer
-- If the callbacks are already being dispatched, wait for them to complete to keep the guarantee
waitForCallbacks
removeHandlerEventually =
modifyMVar_ mvar $ pure . \case
x@(AsyncVarCompleted _ _) -> x
AsyncVarOpen x -> AsyncVarOpen $ HM.delete key x
onResult_ x y = void . onResult x y
tryPutAsyncVarEither :: forall a m. MonadIO m => AsyncVar a -> Either SomeException a -> m Bool
tryPutAsyncVarEither (AsyncVar mvar) value = liftIO $ do
action <- modifyMVar mvar $ \case
x@(AsyncVarCompleted _ waitForCallbacks) -> pure (x, False <$ waitForCallbacks)
AsyncVarOpen callbacksMap -> do
callbacksCompletedMVar <- newEmptyMVar
let waitForCallbacks = readMVar callbacksCompletedMVar
callbacks = HM.elems callbacksMap
pure (AsyncVarCompleted value waitForCallbacks, fireCallbacks callbacks callbacksCompletedMVar)
action
where
fireCallbacks :: [(Either SomeException a -> IO (), SomeException -> IO ())] -> MVar () -> IO Bool
fireCallbacks callbacks callbacksCompletedMVar = do
forM_ callbacks $ \(callback, callbackExceptionHandler) ->
callback value `catch` callbackExceptionHandler
putMVar callbacksCompletedMVar ()
pure True
newAsyncVar :: MonadIO m => m (AsyncVar r)
newAsyncVar = liftIO $ AsyncVar <$> newMVar (AsyncVarOpen HM.empty)
putAsyncVar :: MonadIO m => AsyncVar a -> a -> m ()
putAsyncVar var = putAsyncVarEither var . Right
failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m ()
failAsyncVar var = putAsyncVarEither var . Left
tryPutAsyncVar :: MonadIO m => AsyncVar a -> a -> m Bool
tryPutAsyncVar var = tryPutAsyncVarEither var . Right
tryPutAsyncVar_ :: MonadIO m => AsyncVar a -> a -> m ()
tryPutAsyncVar_ var = void . tryPutAsyncVar var
failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m Bool
failAsyncVar var = tryPutAsyncVarEither var . Left
failAsyncVar_ :: MonadIO m => AsyncVar a -> SomeException -> m ()
failAsyncVar_ var = void . failAsyncVar var
putAsyncVarEither :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
putAsyncVarEither (AsyncVar mvar) value = liftIO $ do
mask $ \restore -> do
takeMVar mvar >>= \case
x@(AsyncVarCompleted _) -> do
putMVar mvar x
fail "An AsyncVar can only be fulfilled once"
AsyncVarOpen callbacksMap -> do
let callbacks = HM.elems callbacksMap
-- NOTE disposing a callback while it is called is a deadlock
forM_ callbacks $ \(callback, callbackExceptionHandler) ->
restore (callback value) `catch` callbackExceptionHandler
putMVar mvar (AsyncVarCompleted value)
putAsyncVarEither avar value = liftIO $ do
success <- tryPutAsyncVarEither avar value
unless success $ fail "An AsyncVar can only be fulfilled once"
tryPutAsyncVarEither_ :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
tryPutAsyncVarEither_ var = void . tryPutAsyncVarEither var
-- * Awaiting multiple asyncs
awaitEither :: (IsAsync ra a , IsAsync rb b) => a -> b -> AsyncIO (Either ra rb)
awaitEither x y = AsyncIOPlumbing $ \_ _ -> AsyncIOAsync <$> awaitEitherPlumbing x y
awaitEitherPlumbing :: (IsAsync ra a , IsAsync rb b) => a -> b -> IO (Async (Either ra rb))
awaitEitherPlumbing x y = do
var <- newAsyncVar
d1 <- onResult x (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Left)
d2 <- onResult y (failAsyncVar_ var) (tryPutAsyncVarEither_ var . fmap Right)
-- The resulting async is kept in memory by 'x' or 'y' until one of them completes.
onResult_ var (const (pure ())) (const (disposeCallbackEventually d1 *> disposeCallbackEventually d2))
pure $ toAsync var
-- * Cancellation
newtype CancellationToken = CancellationToken (AsyncVar Void)
instance IsAsync Void CancellationToken where
toAsync (CancellationToken var) = toAsync var
newCancellationToken :: IO CancellationToken
newCancellationToken = CancellationToken <$> newAsyncVar
cancel :: Exception e => CancellationToken -> e -> IO ()
cancel (CancellationToken var) = failAsyncVar var . toException
cancel (CancellationToken var) = failAsyncVar_ var . toException
isCancellationRequested :: CancellationToken -> IO Bool
isCancellationRequested (CancellationToken var) = isJust <$> peekAsync var
......@@ -280,6 +376,10 @@ throwIfCancellationRequested (CancellationToken var) =
Just (Left ex) -> throwIO ex
_ -> pure ()
awaitUnlessCancellationRequested :: IsAsync a b => CancellationToken -> b -> AsyncIO a
awaitUnlessCancellationRequested cancellationToken = fmap (either absurd id) . awaitEither cancellationToken . toAsync
withCancellationToken :: (CancellationToken -> IO a) -> IO a
withCancellationToken action = do
cancellationToken <- newCancellationToken
......@@ -289,7 +389,8 @@ withCancellationToken action = do
void $ forkIOWithUnmask $ \threadUnmask -> do
putMVar resultMVar =<< try (threadUnmask (action cancellationToken))
either throwIO pure =<< (unmask (takeMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> takeMVar resultMVar))
-- TODO test if it is better to run readMVar recursively or to keep it uninterruptible
either throwIO pure =<< (unmask (readMVar resultMVar) `catchAll` (\ex -> cancel cancellationToken ex >> readMVar resultMVar))
-- * Disposable
......@@ -300,21 +401,22 @@ class IsDisposable a where
-- | Dispose a resource.
dispose :: a -> AsyncIO ()
-- | Dispose a resource in the IO monad.
disposeIO :: a -> IO ()
toDisposable :: a -> Disposable
toDisposable = mkDisposable . dispose
-- | Dispose a resource in the IO monad.
disposeIO :: IsDisposable a => a -> IO ()
disposeIO = runAsyncIO . dispose
instance IsDisposable a => IsDisposable (Maybe a) where
dispose = mapM_ dispose
disposeIO = mapM_ disposeIO
newtype Disposable = Disposable (AsyncIO ())
instance IsDisposable Disposable where
dispose (Disposable fn) = fn
disposeIO = runAsyncIO . dispose
toDisposable = id
instance Semigroup Disposable where
......@@ -333,3 +435,21 @@ synchronousDisposable = mkDisposable . liftIO
noDisposable :: Disposable
noDisposable = mempty
data CallbackDisposable = CallbackDisposable (IO ()) (IO ())
instance IsDisposable CallbackDisposable where
dispose = liftIO . disposeCallback
disposeIO = disposeCallback
toDisposable = Disposable . dispose
disposeCallback :: CallbackDisposable -> IO ()
disposeCallback (CallbackDisposable f _) = f
disposeCallbackEventually :: CallbackDisposable -> IO ()
disposeCallbackEventually (CallbackDisposable _ e) = e
noCallbackDisposable :: CallbackDisposable
noCallbackDisposable = CallbackDisposable mempty mempty
......@@ -80,3 +80,9 @@ spec = parallel $ do
-- Use bind to create an AsyncIOPlumbing, which is the interesting case that uses `uninterruptibleMask` when run
await never >>= pure
result `shouldBe` Nothing
describe "CancellationToken" $ do
it "can be waited upon" $ do
result <- timeout 100000 $ withCancellationToken wait
result `shouldBe` Nothing -- `wait` re-throws the exception
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment