From 1a98d93553c0f995382a93af6feae06336f2f1fb Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Wed, 1 Sep 2021 00:11:06 +0200 Subject: [PATCH] Remove MonadQuerySTM by quantifying MonadAwait actions instead --- src/Quasar/Awaitable.hs | 67 ++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs index dac3d2b..755aa98 100644 --- a/src/Quasar/Awaitable.hs +++ b/src/Quasar/Awaitable.hs @@ -10,7 +10,7 @@ module Quasar.Awaitable ( successfulAwaitable, failedAwaitable, completedAwaitable, - awaitSTM, + awaitableFromSTM, unsafeAwaitSTM, -- * Awaitable helpers @@ -45,7 +45,6 @@ module Quasar.Awaitable ( readAsyncVarSTM, -- * Implementation helpers - MonadQuerySTM(querySTM), cacheAwaitableDefaultImplementation, ) where @@ -63,11 +62,20 @@ import Quasar.Prelude class (MonadCatch m, MonadPlus m) => MonadAwait m where await :: IsAwaitable r a => a -> m r + -- | Await an `STM` transaction. The STM transaction must always return the same result and should not have visible + -- side effects. + -- + -- Use `retry` to signal that the awaitable is not yet completed and use `throwM`/`throwSTM` to signal a failed + -- awaitable. + unsafeAwaitSTM :: STM a -> m a + instance MonadAwait IO where await awaitable = liftIO $ runQueryT atomically (runAwaitable awaitable) + unsafeAwaitSTM = atomically instance MonadAwait m => MonadAwait (ReaderT a m) where await = lift . await + unsafeAwaitSTM = lift . unsafeAwaitSTM awaitResult :: (IsAwaitable r a, MonadAwait m) => m a -> m r @@ -85,7 +93,7 @@ peekAwaitable awaitable = liftIO $ runMaybeT $ runQueryT queryFn (runAwaitable a class IsAwaitable r a | a -> r where - runAwaitable :: (MonadQuerySTM m) => a -> m r + runAwaitable :: (MonadAwait m) => a -> m r runAwaitable self = runAwaitable (toAwaitable self) cacheAwaitable :: MonadIO m => a -> m (Awaitable r) @@ -106,6 +114,7 @@ instance IsAwaitable r (Awaitable r) where instance MonadAwait Awaitable where await = toAwaitable + unsafeAwaitSTM transaction = mkMonadicAwaitable $ unsafeAwaitSTM transaction instance Functor Awaitable where fmap fn (Awaitable x) = mkMonadicAwaitable $ fn <$> runAwaitable x @@ -142,13 +151,13 @@ instance MonadPlus Awaitable -newtype MonadicAwaitable r = MonadicAwaitable (forall m. (MonadQuerySTM m) => m r) +newtype MonadicAwaitable r = MonadicAwaitable (forall m. MonadAwait m => m r) instance IsAwaitable r (MonadicAwaitable r) where runAwaitable (MonadicAwaitable x) = x cacheAwaitable = cacheAwaitableDefaultImplementation -mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadQuerySTM f) => f r) -> m r +mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadAwait f) => f r) -> m r mkMonadicAwaitable fn = await $ MonadicAwaitable fn @@ -175,26 +184,15 @@ failedAwaitable = completedAwaitable . Left -- should not have visible side effects. -- -- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed. -awaitSTM :: MonadIO m => STM a -> m (Awaitable a) -awaitSTM = cacheAwaitable . unsafeAwaitSTM - --- | Create an awaitable from an `STM` transaction. The STM transaction must always return the same result and should --- not have visible side effects. --- --- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed. -unsafeAwaitSTM :: STM a -> Awaitable a -unsafeAwaitSTM query = mkMonadicAwaitable $ querySTM query - +awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a) +awaitableFromSTM transaction = cacheAwaitable (unsafeAwaitSTM transaction :: Awaitable a) -class MonadCatch m => MonadQuerySTM m where - -- | Run an `STM` transaction. Use `retry` to signal that no value is available (yet). - querySTM :: (forall a. STM a -> m a) - -instance MonadCatch m => MonadQuerySTM (ReaderT (QueryFn m) m) where - querySTM query = do +instance {-# OVERLAPS #-} (MonadCatch m, MonadPlus m) => MonadAwait (ReaderT (QueryFn m) m) where + await = runAwaitable + unsafeAwaitSTM transaction = do QueryFn querySTMFn <- ask - lift $ querySTMFn query + lift $ querySTMFn transaction newtype QueryFn m = QueryFn (forall a. STM a -> m a) @@ -208,11 +206,11 @@ cacheAwaitableDefaultImplementation :: (IsAwaitable r a, MonadIO m) => a -> m (A cacheAwaitableDefaultImplementation awaitable = toAwaitable . CachedAwaitable <$> liftIO (newTVarIO (runAwaitable awaitable)) instance IsAwaitable r (CachedAwaitable r) where - runAwaitable :: forall m. (MonadQuerySTM m) => CachedAwaitable r -> m r + runAwaitable :: forall m. MonadAwait m => CachedAwaitable r -> m r runAwaitable (CachedAwaitable tvar) = go where go :: m r - go = querySTM stepCacheTransaction >>= \case + go = unsafeAwaitSTM stepCacheTransaction >>= \case AwaitableCompleted result -> pure result AwaitableFailed ex -> throwM ex -- Cached operation is not yet completed @@ -255,8 +253,9 @@ instance Monad AwaitableStepM where (AwaitableFailed ex) >>= _ = AwaitableFailed ex (AwaitableStep query next) >>= fn = AwaitableStep query (next >=> fn) -instance MonadQuerySTM AwaitableStepM where - querySTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted) +instance MonadAwait AwaitableStepM where + await = runAwaitable + unsafeAwaitSTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted) instance MonadThrow AwaitableStepM where throwM = AwaitableFailed . toException @@ -266,6 +265,12 @@ instance MonadCatch AwaitableStepM where catch result@(AwaitableFailed ex) handler = maybe result handler $ fromException ex catch (AwaitableStep query next) handler = AwaitableStep query (\x -> next x `catch` handler) +instance Alternative AwaitableStepM where + x <|> y = x `catchAll` const y + empty = throwM $ toException $ userError "empty" + +instance MonadPlus AwaitableStepM where + -- ** AsyncVar @@ -273,7 +278,7 @@ instance MonadCatch AwaitableStepM where newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r)) instance IsAwaitable r (AsyncVar r) where - runAwaitable (AsyncVar var) = querySTM $ either throwM pure =<< readTMVar var + runAwaitable (AsyncVar var) = unsafeAwaitSTM $ either throwM pure =<< readTMVar var -- An AsyncVar is a primitive awaitable, so caching is not necessary cacheAwaitable = pure . toAwaitable @@ -345,13 +350,13 @@ awaitSuccessOrFailure = await . fireAndForget . toAwaitable awaitEither :: (IsAwaitable ra a, IsAwaitable rb b, MonadAwait m) => a -> b -> m (Either ra rb) awaitEither x y = mkMonadicAwaitable $ stepBoth (runAwaitable x) (runAwaitable y) where - stepBoth :: MonadQuerySTM m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb) + stepBoth :: MonadAwait m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb) stepBoth (AwaitableCompleted resultX) _ = pure $ Left resultX stepBoth (AwaitableFailed ex) _ = throwM ex stepBoth _ (AwaitableCompleted resultY) = pure $ Right resultY stepBoth _ (AwaitableFailed ex) = throwM ex stepBoth stepX@(AwaitableStep transactionX nextX) stepY@(AwaitableStep transactionY nextY) = do - querySTM (eitherSTM (try transactionX) (try transactionY)) >>= \case + unsafeAwaitSTM (eitherSTM (try transactionX) (try transactionY)) >>= \case Left resultX -> stepBoth (nextX resultX) stepY Right resultY -> stepBoth stepX (nextY resultY) @@ -360,7 +365,7 @@ awaitAny :: (IsAwaitable r a, MonadAwait m) => NonEmpty a -> m r awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromList (toList xs) where stepAll - :: MonadQuerySTM m + :: MonadAwait m => Seq (STM (Seq (AwaitableStepM r))) -> Seq (AwaitableStepM r) -> Seq (AwaitableStepM r) @@ -373,7 +378,7 @@ awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromLi do prevSteps |> step steps stepAll acc _ Empty = do - newAwaitableSteps <- querySTM $ maybe impossibleCodePathM anySTM $ nonEmpty (toList acc) + newAwaitableSteps <- unsafeAwaitSTM $ maybe impossibleCodePathM anySTM $ nonEmpty (toList acc) stepAll Empty Empty newAwaitableSteps -- GitLab