From 1a98d93553c0f995382a93af6feae06336f2f1fb Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Wed, 1 Sep 2021 00:11:06 +0200
Subject: [PATCH] Remove MonadQuerySTM by quantifying MonadAwait actions
 instead

---
 src/Quasar/Awaitable.hs | 67 ++++++++++++++++++++++-------------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index dac3d2b..755aa98 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -10,7 +10,7 @@ module Quasar.Awaitable (
   successfulAwaitable,
   failedAwaitable,
   completedAwaitable,
-  awaitSTM,
+  awaitableFromSTM,
   unsafeAwaitSTM,
 
   -- * Awaitable helpers
@@ -45,7 +45,6 @@ module Quasar.Awaitable (
   readAsyncVarSTM,
 
   -- * Implementation helpers
-  MonadQuerySTM(querySTM),
   cacheAwaitableDefaultImplementation,
 ) where
 
@@ -63,11 +62,20 @@ import Quasar.Prelude
 class (MonadCatch m, MonadPlus m) => MonadAwait m where
   await :: IsAwaitable r a => a -> m r
 
+  -- | Await an `STM` transaction. The STM transaction must always return the same result and should not have visible
+  -- side effects.
+  --
+  -- Use `retry` to signal that the awaitable is not yet completed and use `throwM`/`throwSTM` to signal a failed
+  -- awaitable.
+  unsafeAwaitSTM :: STM a -> m a
+
 instance MonadAwait IO where
   await awaitable = liftIO $ runQueryT atomically (runAwaitable awaitable)
+  unsafeAwaitSTM = atomically
 
 instance MonadAwait m => MonadAwait (ReaderT a m) where
   await = lift . await
+  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 
 awaitResult :: (IsAwaitable r a, MonadAwait m) => m a -> m r
@@ -85,7 +93,7 @@ peekAwaitable awaitable = liftIO $ runMaybeT $ runQueryT queryFn (runAwaitable a
 
 
 class IsAwaitable r a | a -> r where
-  runAwaitable :: (MonadQuerySTM m) => a -> m r
+  runAwaitable :: (MonadAwait m) => a -> m r
   runAwaitable self = runAwaitable (toAwaitable self)
 
   cacheAwaitable :: MonadIO m => a -> m (Awaitable r)
@@ -106,6 +114,7 @@ instance IsAwaitable r (Awaitable r) where
 
 instance MonadAwait Awaitable where
   await = toAwaitable
+  unsafeAwaitSTM transaction = mkMonadicAwaitable $ unsafeAwaitSTM transaction
 
 instance Functor Awaitable where
   fmap fn (Awaitable x) = mkMonadicAwaitable $ fn <$> runAwaitable x
@@ -142,13 +151,13 @@ instance MonadPlus Awaitable
 
 
 
-newtype MonadicAwaitable r = MonadicAwaitable (forall m. (MonadQuerySTM m) => m r)
+newtype MonadicAwaitable r = MonadicAwaitable (forall m. MonadAwait m => m r)
 
 instance IsAwaitable r (MonadicAwaitable r) where
   runAwaitable (MonadicAwaitable x) = x
   cacheAwaitable = cacheAwaitableDefaultImplementation
 
-mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadQuerySTM f) => f r) -> m r
+mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadAwait f) => f r) -> m r
 mkMonadicAwaitable fn = await $ MonadicAwaitable fn
 
 
@@ -175,26 +184,15 @@ failedAwaitable = completedAwaitable . Left
 -- should not have visible side effects.
 --
 -- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
-awaitSTM :: MonadIO m => STM a -> m (Awaitable a)
-awaitSTM = cacheAwaitable . unsafeAwaitSTM
-
--- | Create an awaitable from an `STM` transaction. The STM transaction must always return the same result and should
--- not have visible side effects.
---
--- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
-unsafeAwaitSTM :: STM a -> Awaitable a
-unsafeAwaitSTM query = mkMonadicAwaitable $ querySTM query
-
+awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a)
+awaitableFromSTM transaction = cacheAwaitable (unsafeAwaitSTM transaction :: Awaitable a)
 
-class MonadCatch m => MonadQuerySTM m where
-  -- | Run an `STM` transaction. Use `retry` to signal that no value is available (yet).
-  querySTM :: (forall a. STM a -> m a)
 
-
-instance MonadCatch m => MonadQuerySTM (ReaderT (QueryFn m) m) where
-  querySTM query = do
+instance {-# OVERLAPS #-} (MonadCatch m, MonadPlus m) => MonadAwait (ReaderT (QueryFn m) m) where
+  await = runAwaitable
+  unsafeAwaitSTM transaction = do
     QueryFn querySTMFn <- ask
-    lift $ querySTMFn query
+    lift $ querySTMFn transaction
 
 newtype QueryFn m = QueryFn (forall a. STM a -> m a)
 
@@ -208,11 +206,11 @@ cacheAwaitableDefaultImplementation :: (IsAwaitable r a, MonadIO m) => a -> m (A
 cacheAwaitableDefaultImplementation awaitable = toAwaitable . CachedAwaitable <$> liftIO (newTVarIO (runAwaitable awaitable))
 
 instance IsAwaitable r (CachedAwaitable r) where
-  runAwaitable :: forall m. (MonadQuerySTM m) => CachedAwaitable r -> m r
+  runAwaitable :: forall m. MonadAwait m => CachedAwaitable r -> m r
   runAwaitable (CachedAwaitable tvar) = go
     where
       go :: m r
-      go = querySTM stepCacheTransaction >>= \case
+      go = unsafeAwaitSTM stepCacheTransaction >>= \case
         AwaitableCompleted result -> pure result
         AwaitableFailed ex -> throwM ex
         -- Cached operation is not yet completed
@@ -255,8 +253,9 @@ instance Monad AwaitableStepM where
   (AwaitableFailed ex) >>= _ = AwaitableFailed ex
   (AwaitableStep query next) >>= fn = AwaitableStep query (next >=> fn)
 
-instance MonadQuerySTM AwaitableStepM where
-  querySTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted)
+instance MonadAwait AwaitableStepM where
+  await = runAwaitable
+  unsafeAwaitSTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted)
 
 instance MonadThrow AwaitableStepM where
   throwM = AwaitableFailed . toException
@@ -266,6 +265,12 @@ instance MonadCatch AwaitableStepM where
   catch result@(AwaitableFailed ex) handler = maybe result handler $ fromException ex
   catch (AwaitableStep query next) handler = AwaitableStep query (\x -> next x `catch` handler)
 
+instance Alternative AwaitableStepM where
+  x <|> y = x `catchAll` const y
+  empty = throwM $ toException $ userError "empty"
+
+instance MonadPlus AwaitableStepM where
+
 
 -- ** AsyncVar
 
@@ -273,7 +278,7 @@ instance MonadCatch AwaitableStepM where
 newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r))
 
 instance IsAwaitable r (AsyncVar r) where
-  runAwaitable (AsyncVar var) = querySTM $ either throwM pure =<< readTMVar var
+  runAwaitable (AsyncVar var) = unsafeAwaitSTM $ either throwM pure =<< readTMVar var
   -- An AsyncVar is a primitive awaitable, so caching is not necessary
   cacheAwaitable = pure . toAwaitable
 
@@ -345,13 +350,13 @@ awaitSuccessOrFailure = await . fireAndForget . toAwaitable
 awaitEither :: (IsAwaitable ra a, IsAwaitable rb b, MonadAwait m) => a -> b -> m (Either ra rb)
 awaitEither x y = mkMonadicAwaitable $ stepBoth (runAwaitable x) (runAwaitable y)
   where
-    stepBoth :: MonadQuerySTM m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb)
+    stepBoth :: MonadAwait m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb)
     stepBoth (AwaitableCompleted resultX) _ = pure $ Left resultX
     stepBoth (AwaitableFailed ex) _ = throwM ex
     stepBoth _ (AwaitableCompleted resultY) = pure $ Right resultY
     stepBoth _ (AwaitableFailed ex) = throwM ex
     stepBoth stepX@(AwaitableStep transactionX nextX) stepY@(AwaitableStep transactionY nextY) = do
-      querySTM (eitherSTM (try transactionX) (try transactionY)) >>= \case
+      unsafeAwaitSTM (eitherSTM (try transactionX) (try transactionY)) >>= \case
         Left resultX -> stepBoth (nextX resultX) stepY
         Right resultY -> stepBoth stepX (nextY resultY)
 
@@ -360,7 +365,7 @@ awaitAny :: (IsAwaitable r a, MonadAwait m) => NonEmpty a -> m r
 awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromList (toList xs)
   where
     stepAll
-      :: MonadQuerySTM m
+      :: MonadAwait m
       => Seq (STM (Seq (AwaitableStepM r)))
       -> Seq (AwaitableStepM r)
       -> Seq (AwaitableStepM r)
@@ -373,7 +378,7 @@ awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromLi
         do prevSteps |> step
         steps
     stepAll acc _ Empty = do
-      newAwaitableSteps <- querySTM $ maybe impossibleCodePathM anySTM $ nonEmpty (toList acc)
+      newAwaitableSteps <- unsafeAwaitSTM $ maybe impossibleCodePathM anySTM $ nonEmpty (toList acc)
       stepAll Empty Empty newAwaitableSteps
 
 
-- 
GitLab