module Quasar.Awaitable (
  -- * MonadAwaitable
  MonadAwait(..),
  awaitResult,
  peekAwaitable,

  -- * Awaitable
  IsAwaitable(toAwaitable),
  Awaitable,
  successfulAwaitable,
  failedAwaitable,
  completedAwaitable,
  awaitableFromSTM,
  cacheAwaitable,

  -- * Awaitable helpers
  afix,
  afix_,
  awaitSuccessOrFailure,

  -- ** Awaiting multiple awaitables
  awaitAny,
  awaitAny2,
  awaitEither,

  -- * AsyncVar
  AsyncVar,

  -- ** Manage `AsyncVar`s in IO
  newAsyncVar,
  putAsyncVarEither,
  putAsyncVar,
  putAsyncVar_,
  failAsyncVar,
  failAsyncVar_,
  putAsyncVarEither_,

  -- ** Manage `AsyncVar`s in STM
  newAsyncVarSTM,
  putAsyncVarEitherSTM,
  putAsyncVarSTM,
  putAsyncVarSTM_,
  failAsyncVarSTM,
  failAsyncVarSTM_,
  putAsyncVarEitherSTM_,
  readAsyncVarSTM,
  tryReadAsyncVarSTM,
) where

import Control.Applicative (empty)
import Control.Concurrent.STM
import Control.Exception (BlockedIndefinitelyOnSTM(..))
import Control.Monad.Catch
import Control.Monad.Reader
import Control.Monad.Writer (WriterT)
import Control.Monad.State (StateT)
import Control.Monad.RWS (RWST)
import Control.Monad.Trans.Maybe
import Data.List.NonEmpty (NonEmpty(..), nonEmpty)
import Data.Foldable (toList)
import Data.Sequence
import GHC.IO (unsafeDupablePerformIO)
import Quasar.Prelude


-- | Monads in which an awaitable can be awaited.
class (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait m where
  -- | Wait until an awaitable is completed and then return its value (or throw an exception).
  await :: IsAwaitable r a => a -> m r

  -- | Await an `STM` transaction. The STM transaction must always return the same result and should not have visible
  -- side effects.
  --
  -- Use `retry` to signal that the awaitable is not yet completed and use `throwM`/`throwSTM` to signal a failed
  -- awaitable.
  unsafeAwaitSTM :: STM a -> m a

-- | Thrown by `await` (in `IO` and `STM`) in place of `BlockedIndefinitelyOnSTM`.
data BlockedIndefinitelyOnAwait = BlockedIndefinitelyOnAwait
  deriving stock Show

instance Exception BlockedIndefinitelyOnAwait where
  displayException BlockedIndefinitelyOnAwait = "Thread blocked indefinitely in an 'await' operation"


-- | Every query of the awaitable is run with `atomically`; a `BlockedIndefinitelyOnSTM` raised while doing so is
-- rethrown as `BlockedIndefinitelyOnAwait`.
instance MonadAwait IO where
  await awaitable = liftIO $
    runQueryT atomically (runAwaitable awaitable)
      `catch` \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
  unsafeAwaitSTM = atomically

-- | Experimental instance for `STM`. Using `await` in STM circumvents awaitable caching mechanics, so this only
-- exists as a test to estimate the usefulness of caching awaitables against the usefulness of awaiting in STM.
instance MonadAwait STM where
  await awaitable =
    runQueryT id (runAwaitable awaitable)
      `catch` \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
  unsafeAwaitSTM = id

-- The transformer instances below simply delegate to the underlying monad via `lift`.

instance MonadAwait m => MonadAwait (ReaderT a m) where
  await awaitable = lift (await awaitable)
  unsafeAwaitSTM transaction = lift (unsafeAwaitSTM transaction)

instance (MonadAwait m, Monoid a) => MonadAwait (WriterT a m) where
  await awaitable = lift (await awaitable)
  unsafeAwaitSTM transaction = lift (unsafeAwaitSTM transaction)

instance MonadAwait m => MonadAwait (StateT a m) where
  await awaitable = lift (await awaitable)
  unsafeAwaitSTM transaction = lift (unsafeAwaitSTM transaction)

instance (MonadAwait m, Monoid w) => MonadAwait (RWST r w s m) where
  await awaitable = lift (await awaitable)
  unsafeAwaitSTM transaction = lift (unsafeAwaitSTM transaction)

instance MonadAwait m => MonadAwait (MaybeT m) where
  await awaitable = lift (await awaitable)
  unsafeAwaitSTM transaction = lift (unsafeAwaitSTM transaction)


-- | Run a monadic action that produces an awaitable, then `await` that awaitable.
awaitResult :: (IsAwaitable r a, MonadAwait m) => m a -> m r
awaitResult action = action >>= await

-- | Returns the result (in a `Just`) when the awaitable is completed, throws an `Exception` when the awaitable is
-- failed and returns `Nothing` otherwise.
peekAwaitable :: (IsAwaitable r a, MonadIO m) => a -> m (Maybe r)
peekAwaitable awaitable = liftIO (runMaybeT (runQueryT pollSTM (runAwaitable awaitable)))
  where
    -- Run a single query without blocking: a query that would `retry` makes the whole computation yield `Nothing`.
    pollSTM :: STM a -> MaybeT IO a
    pollSTM stm = MaybeT $ atomically $ fmap Just stm `orElse` pure Nothing


-- | Types that can be awaited, yielding a result of type @r@.
class IsAwaitable r a | a -> r where
  -- | Wrap the value in the existential `Awaitable` type.
  toAwaitable :: a -> Awaitable r
  toAwaitable = Awaitable

  -- | Run the awaitable. When interacting with an awaitable you usually want to use `await` instead. `runAwaitable` is
  -- used to manually implement an instance of `IsAwaitable`.
  --
  -- The implementation of `async` calls `runAwaitable` in most monads, so the implementation of `runAwaitable` must
  -- not call `async` without deconstructing first.
  runAwaitable :: (MonadAwait m) => a -> m r
  runAwaitable self = runAwaitable (toAwaitable self)

  -- | Produce a cached version of the awaitable. The instances in this module that are already primitive or cached
  -- return themselves unchanged.
  cacheAwaitableUnlessPrimitive :: MonadIO m => a -> m (Awaitable r)
  cacheAwaitableUnlessPrimitive self = cacheAwaitableUnlessPrimitive (toAwaitable self)

  {-# MINIMAL toAwaitable | (runAwaitable, cacheAwaitableUnlessPrimitive) #-}


-- | Cache an `Awaitable`; delegates to `cacheAwaitableUnlessPrimitive`.
cacheAwaitable :: MonadIO m => Awaitable r -> m (Awaitable r)
cacheAwaitable = cacheAwaitableUnlessPrimitive


-- | Existential wrapper around any `IsAwaitable` value with result type @r@.
data Awaitable r = forall a. IsAwaitable r a => Awaitable a

instance IsAwaitable r (Awaitable r) where
  runAwaitable (Awaitable x) = runAwaitable x
  cacheAwaitableUnlessPrimitive (Awaitable x) = cacheAwaitableUnlessPrimitive x
  toAwaitable = id

instance MonadAwait Awaitable where
  await = toAwaitable
  unsafeAwaitSTM transaction = mkMonadicAwaitable $ unsafeAwaitSTM transaction

instance Functor Awaitable where
  fmap fn (Awaitable x) = mkMonadicAwaitable $ fn <$> runAwaitable x

instance Applicative Awaitable where
  pure = successfulAwaitable
  liftA2 fn (Awaitable fx) (Awaitable fy) = mkMonadicAwaitable $ liftA2 fn (runAwaitable fx) (runAwaitable fy)

instance Monad Awaitable where
  (Awaitable fx) >>= fn = mkMonadicAwaitable $ runAwaitable fx >>= runAwaitable . fn

instance Semigroup r => Semigroup (Awaitable r) where
  (<>) = liftA2 (<>)

instance Monoid r => Monoid (Awaitable r) where
  mempty = pure mempty

instance MonadThrow Awaitable where
  throwM = failedAwaitable . toException

instance MonadCatch Awaitable where
  catch awaitable handler =
    mkMonadicAwaitable (runAwaitable awaitable `catch` \ex -> runAwaitable (handler ex))

instance MonadFail Awaitable where
  fail = throwM . userError

-- | `<|>` falls back to the second awaitable when the first one fails with any exception.
instance Alternative Awaitable where
  x <|> y = x `catchAll` const y
  empty = failedAwaitable $ toException $ userError "empty"

instance MonadPlus Awaitable

instance MonadFix Awaitable where
  mfix fn = mkMonadicAwaitable $ mfix \x -> runAwaitable (fn x)



-- | An awaitable given as a computation that can be run in any `MonadAwait` monad.
newtype MonadicAwaitable r = MonadicAwaitable (forall m. MonadAwait m => m r)

instance IsAwaitable r (MonadicAwaitable r) where
  runAwaitable (MonadicAwaitable x) = x
  cacheAwaitableUnlessPrimitive = cacheAwaitableDefaultImplementation

-- | Wrap a computation that is polymorphic in its `MonadAwait` monad as a `MonadicAwaitable` and `await` it.
mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadAwait f) => f r) -> m r
mkMonadicAwaitable fn = await $ MonadicAwaitable fn


-- | An awaitable whose outcome (failure or result) is already known.
newtype CompletedAwaitable r = CompletedAwaitable (Either SomeException r)

instance IsAwaitable r (CompletedAwaitable r) where
  runAwaitable (CompletedAwaitable x) = either throwM pure x
  -- Already completed, so there is nothing to cache
  cacheAwaitableUnlessPrimitive = pure . toAwaitable


-- | Create an awaitable that is already completed with the given outcome.
completedAwaitable :: Either SomeException r -> Awaitable r
completedAwaitable = toAwaitable . CompletedAwaitable

-- | Alias for `pure`.
successfulAwaitable :: r -> Awaitable r
successfulAwaitable result = completedAwaitable (Right result)

-- | Create an awaitable that has already failed with the given exception.
failedAwaitable :: SomeException -> Awaitable r
failedAwaitable ex = completedAwaitable (Left ex)

-- | Create an awaitable from an `STM` transaction.
--
-- The first value or exception returned from the STM transaction will be cached and returned. The STM transaction
-- should not have visible side effects.
--
-- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a)
awaitableFromSTM transaction = cacheAwaitableUnlessPrimitive (unsafeAwaitSTM transaction :: Awaitable a)


-- | Reader-based `MonadAwait` implementation: every STM query is passed to the `QueryFn` held in the environment.
instance {-# OVERLAPS #-} (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait (ReaderT (QueryFn m) m) where
  await = runAwaitable
  unsafeAwaitSTM transaction = do
    QueryFn querySTMFn <- ask
    lift $ querySTMFn transaction

-- | A function that executes STM queries in the monad @m@.
newtype QueryFn m = QueryFn (forall a. STM a -> m a)

-- | Run a query computation, using the given function to execute each STM query.
runQueryT :: forall m a. (forall b. STM b -> m b) -> ReaderT (QueryFn m) m a -> m a
runQueryT queryFn action = runReaderT action (QueryFn queryFn)


-- TODO add guard to only allow one thread to step cache
-- | An awaitable whose progress is stored in a `TVar`.
newtype CachedAwaitable r = CachedAwaitable (TVar (AwaitableStepM r))

-- | Default caching strategy: run the awaitable in `AwaitableStepM` and store the resulting step in a new `TVar`.
cacheAwaitableDefaultImplementation :: (IsAwaitable r a, MonadIO m) => a -> m (Awaitable r)
cacheAwaitableDefaultImplementation awaitable = toAwaitable . CachedAwaitable <$> liftIO (newTVarIO (runAwaitable awaitable))

instance IsAwaitable r (CachedAwaitable r) where
  runAwaitable :: forall m. MonadAwait m => CachedAwaitable r -> m r
  runAwaitable (CachedAwaitable tvar) = go
    where
      -- Step the cache until it holds a final result
      go :: m r
      go = unsafeAwaitSTM stepCacheTransaction >>= \case
        AwaitableCompleted result -> pure result
        AwaitableFailed ex -> throwM ex
        -- Cached operation is not yet completed
        _ -> go

      stepCacheTransaction :: STM (AwaitableStepM r)
      stepCacheTransaction = do
        readTVar tvar >>= \case
          -- Cache needs to be stepped
          AwaitableStep query fn -> do
            -- Run the next "querySTM" query requested by the cached operation
            -- The query might `retry`, which is ok here
            nextStep <- fn <$> try query
            -- In case of an incomplete query the caller (/ the monad `m`) can decide what to do (e.g. retry for
            -- `awaitIO`, abort for `peekAwaitable`)
            -- Query was successful. Update cache and exit query
            writeTVar tvar nextStep
            pure nextStep
          -- Cache was already completed
          result -> pure result

  -- Already cached, so it is returned as-is
  cacheAwaitableUnlessPrimitive = pure . toAwaitable

-- | A computation expressed as a series of STM queries: either finished (completed or failed), or a pending step that
-- consists of an STM query and a continuation receiving the query's result or exception.
data AwaitableStepM a
  = AwaitableCompleted a
  | AwaitableFailed SomeException
  | forall b. AwaitableStep (STM b) (Either SomeException b -> AwaitableStepM a)

instance Functor AwaitableStepM where
  fmap fn (AwaitableCompleted x) = AwaitableCompleted (fn x)
  fmap _ (AwaitableFailed ex) = AwaitableFailed ex
  fmap fn (AwaitableStep query next) = AwaitableStep query (fmap fn . next)

instance Applicative AwaitableStepM where
  pure = AwaitableCompleted
  liftA2 fn fx fy = fx >>= \x -> fn x <$> fy

instance Monad AwaitableStepM where
  (AwaitableCompleted x) >>= fn = fn x
  (AwaitableFailed ex) >>= _ = AwaitableFailed ex
  (AwaitableStep query next) >>= fn = AwaitableStep query (next >=> fn)

instance MonadAwait AwaitableStepM where
  await = runAwaitable
  unsafeAwaitSTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted)

instance MonadThrow AwaitableStepM where
  throwM = AwaitableFailed . toException

instance MonadCatch AwaitableStepM where
  catch result@(AwaitableCompleted _) _ = result
  -- The handler only applies when the stored exception has the type the handler expects
  catch result@(AwaitableFailed ex) handler = maybe result handler $ fromException ex
  catch (AwaitableStep query next) handler = AwaitableStep query (\x -> next x `catch` handler)

instance MonadFail AwaitableStepM where
  fail = throwM . userError

instance Alternative AwaitableStepM where
  x <|> y = x `catchAll` const y
  empty = throwM $ toException $ userError "empty"

instance MonadPlus AwaitableStepM

instance MonadFix AwaitableStepM where
  mfix :: forall a. (a -> AwaitableStepM a) -> AwaitableStepM a
  mfix fn = AwaitableStep newEmptyTMVar applyFix
    where
      applyFix :: Either SomeException (TMVar a) -> AwaitableStepM a
      applyFix (Right var) = do
        -- The argument passed to `fn` is a lazy read of `var`, which is only filled by `storeResult` afterwards
        result <- fn $ unsafeDupablePerformIO do
          atomically (readTMVar var) `catch` \BlockedIndefinitelyOnSTM -> throwIO FixAwaitException
        storeResult var result
      applyFix (Left _) = unreachableCodePathM -- `newEmptyTMVar` should never fail

      storeResult :: TMVar a -> a -> AwaitableStepM a
      storeResult var x = AwaitableStep (void $ tryPutTMVar var x) (\_ -> pure x)


-- | Thrown from the lazy fixed-point value of `mfix` for `AwaitableStepM` when reading it blocks indefinitely.
data FixAwaitException = FixAwaitException
  deriving stock (Eq, Show)
  deriving anyclass Exception



-- ** AsyncVar

-- | The default implementation for an `Awaitable` that can be fulfilled later.
newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r))

instance IsAwaitable r (AsyncVar r) where
  runAwaitable (AsyncVar var) = unsafeAwaitSTM (readTMVar var >>= either throwM pure)
  -- An AsyncVar is a primitive awaitable, so caching is not necessary
  cacheAwaitableUnlessPrimitive = pure . toAwaitable


-- | Create an empty `AsyncVar` in `STM`.
newAsyncVarSTM :: STM (AsyncVar r)
newAsyncVarSTM = fmap AsyncVar newEmptyTMVar

-- | Create an empty `AsyncVar`.
newAsyncVar :: MonadIO m => m (AsyncVar r)
newAsyncVar = liftIO (fmap AsyncVar newEmptyTMVarIO)


-- | Try to complete the `AsyncVar` with a result or an exception. The returned `Bool` is the result of the underlying
-- `tryPutTMVar`.
putAsyncVarEither :: forall a m. MonadIO m => AsyncVar a -> Either SomeException a -> m Bool
putAsyncVarEither var value = liftIO $ atomically $ putAsyncVarEitherSTM var value

-- | `STM` version of `putAsyncVarEither`.
putAsyncVarEitherSTM :: AsyncVar a -> Either SomeException a -> STM Bool
putAsyncVarEitherSTM (AsyncVar var) value = tryPutTMVar var value


-- | Get the value of an `AsyncVar` in `STM`. Will retry until the AsyncVar is fulfilled.
readAsyncVarSTM :: AsyncVar a -> STM a
readAsyncVarSTM (AsyncVar var) = readTMVar var >>= either throwM pure

-- | Non-blocking read of an `AsyncVar` in `STM`: `Nothing` while the var is empty; a stored exception is rethrown.
tryReadAsyncVarSTM :: forall a. AsyncVar a -> STM (Maybe a)
tryReadAsyncVarSTM (AsyncVar var) = tryReadTMVar var >>= traverse (either throwM pure)


-- | Try to complete the `AsyncVar` with a value.
putAsyncVar :: MonadIO m => AsyncVar a -> a -> m Bool
putAsyncVar var value = putAsyncVarEither var (Right value)

-- | `STM` version of `putAsyncVar`.
putAsyncVarSTM :: AsyncVar a -> a -> STM Bool
putAsyncVarSTM var value = putAsyncVarEitherSTM var (Right value)

-- | Like `putAsyncVar`, but discards the `Bool`.
putAsyncVar_ :: MonadIO m => AsyncVar a -> a -> m ()
putAsyncVar_ var value = void (putAsyncVar var value)

-- | Like `putAsyncVarSTM`, but discards the `Bool`.
putAsyncVarSTM_ :: AsyncVar a -> a -> STM ()
putAsyncVarSTM_ var value = void (putAsyncVarSTM var value)

-- | Try to complete the `AsyncVar` with an exception.
failAsyncVar :: (Exception e, MonadIO m) => AsyncVar a -> e -> m Bool
failAsyncVar var ex = putAsyncVarEither var (Left (toException ex))

-- | `STM` version of `failAsyncVar`.
failAsyncVarSTM :: Exception e => AsyncVar a -> e -> STM Bool
failAsyncVarSTM var ex = putAsyncVarEitherSTM var (Left (toException ex))

-- | Like `failAsyncVar`, but discards the `Bool`.
failAsyncVar_ :: (Exception e, MonadIO m) => AsyncVar a -> e -> m ()
failAsyncVar_ var ex = void (failAsyncVar var ex)

-- | Like `failAsyncVarSTM`, but discards the `Bool`.
failAsyncVarSTM_ :: Exception e => AsyncVar a -> e -> STM ()
failAsyncVarSTM_ var ex = void (failAsyncVarSTM var ex)

-- | Like `putAsyncVarEither`, but discards the `Bool`.
putAsyncVarEither_ :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
putAsyncVarEither_ var value = void (putAsyncVarEither var value)

-- | Like `putAsyncVarEitherSTM`, but discards the `Bool`.
putAsyncVarEitherSTM_ :: AsyncVar a -> Either SomeException a -> STM ()
putAsyncVarEitherSTM_ var value = void (putAsyncVarEitherSTM var value)



-- * Utility functions

-- | Await success or failure of another awaitable, then return `()`.
awaitSuccessOrFailure :: (IsAwaitable r a, MonadAwait m) => a -> m ()
awaitSuccessOrFailure awaitable = await (fireAndForget (toAwaitable awaitable))
  where
    -- Discard the result and swallow any exception
    fireAndForget :: MonadCatch m => m r -> m ()
    fireAndForget x = void x `catchAll` \_ -> pure ()

-- | Run an action that receives an awaitable for its own result. The awaitable is completed with the action's result,
-- or failed with the exception thrown by the action (which is then rethrown).
afix :: (MonadIO m, MonadCatch m) => (Awaitable a -> m a) -> m a
afix action = do
  var <- newAsyncVar
  let
    fulfil = do
      result <- action (toAwaitable var)
      putAsyncVar_ var result
      pure result
  fulfil `catchAll` \ex -> do
    failAsyncVar_ var ex
    throwM ex

-- | Like `afix`, but discards the result.
afix_ :: (MonadIO m, MonadCatch m) => (Awaitable a -> m a) -> m ()
afix_ action = void (afix action)


-- ** Awaiting multiple awaitables

-- | Completes as soon as either awaitable completes.
awaitEither :: MonadAwait m => Awaitable ra -> Awaitable rb -> m (Either ra rb)
awaitEither x y = mkMonadicAwaitable $ stepBoth (runAwaitable x) (runAwaitable y)
  where
    -- Both awaitables are run as `AwaitableStepM` computations and advanced one query at a time. An outcome of the
    -- left awaitable is checked before an outcome of the right one; a failure of either is rethrown.
    stepBoth :: MonadAwait m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb)
    stepBoth (AwaitableCompleted resultX) _ = pure $ Left resultX
    stepBoth (AwaitableFailed ex) _ = throwM ex
    stepBoth _ (AwaitableCompleted resultY) = pure $ Right resultY
    stepBoth _ (AwaitableFailed ex) = throwM ex
    stepBoth stepX@(AwaitableStep transactionX nextX) stepY@(AwaitableStep transactionY nextY) = do
      -- Wait for whichever pending query finishes first (left-biased) and advance only that side
      unsafeAwaitSTM (eitherSTM (try transactionX) (try transactionY)) >>= \case
        Left resultX -> stepBoth (nextX resultX) stepY
        Right resultY -> stepBoth stepX (nextY resultY)

-- | Helper for `awaitEither`
eitherSTM :: STM a -> STM b -> STM (Either a b)
eitherSTM x y = fmap Left x `orElse` fmap Right y


-- | Completes as soon as any awaitable in the list is completed and then returns the left-most completed result
-- (or exception).
--
-- An empty list results in an awaitable that never completes (the underlying transaction always `retry`s).
awaitAny :: MonadAwait m => [Awaitable r] -> m r
awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromList xs
  where
    -- Arguments:
    --
    -- 1. For every pending step seen so far, a transaction that runs its query and returns the complete list of steps
    --    with that one step advanced.
    -- 2. The steps that have already been inspected in this pass.
    -- 3. The steps that still have to be inspected.
    stepAll
      :: MonadAwait m
      => Seq (STM (Seq (AwaitableStepM r)))
      -> Seq (AwaitableStepM r)
      -> Seq (AwaitableStepM r)
      -> m r
    stepAll _ _ (AwaitableCompleted result :<| _) = pure result
    stepAll _ _ (AwaitableFailed ex :<| _) = throwM ex
    stepAll acc prevSteps (step@(AwaitableStep transaction next) :<| steps) =
      stepAll
        (acc |> ((\result -> (prevSteps |> next result) <> steps) <$> try transaction))
        (prevSteps |> step)
        steps
    stepAll acc _ Empty = do
      -- `acc` is only empty when `awaitAny` was called with an empty list. With nothing to wait for, the awaitable
      -- never completes, which matches how `anySTM` behaves once it has run out of alternatives.
      newAwaitableSteps <- unsafeAwaitSTM $ maybe retry anySTM $ nonEmpty (toList acc)
      stepAll Empty Empty newAwaitableSteps

-- | Helper for `awaitAny`
anySTM :: NonEmpty (STM a) -> STM a
anySTM (x :| xs) = x `orElse` maybe retry anySTM (nonEmpty xs)


-- | Like `awaitAny` with two awaitables.
awaitAny2 :: MonadAwait m => Awaitable r -> Awaitable r -> m r
awaitAny2 x y = awaitAny [toAwaitable x, toAwaitable y]