From 319cc296a57f61c4a9aa8e636486a4478db88032 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 30 Nov 2021 22:18:06 +0100
Subject: [PATCH] Implement MonadAwait for STM (as an experiment)

---
 src/Quasar/Awaitable.hs | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index f1b9853..ff7ce9c 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -63,7 +63,7 @@ import GHC.IO (unsafeDupablePerformIO)
 import Quasar.Prelude
 
 
-class (MonadCatch m, MonadFail m, MonadPlus m, MonadFix m) => MonadAwait m where
+class (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait m where
   -- | Wait until an awaitable is completed and then return it's value (or throw an exception).
   await :: IsAwaitable r a => a -> m r
 
@@ -88,6 +88,15 @@ instance MonadAwait IO where
         \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
   unsafeAwaitSTM = atomically
 
+-- | Experimental instance for `STM`. Using `await` in STM circumvents awaitable caching mechanics, so this only
+-- exists as a test to estimate the usefulness of caching awaitables against the usefulness of awaiting in STM.
+instance MonadAwait STM where
+  await awaitable =
+    runQueryT id (runAwaitable awaitable)
+      `catch`
+        \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
+  unsafeAwaitSTM = id
+
 instance MonadAwait m => MonadAwait (ReaderT a m) where
   await = lift . await
   unsafeAwaitSTM = lift . unsafeAwaitSTM
@@ -231,7 +240,7 @@ awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a)
 awaitableFromSTM transaction = cacheAwaitableUnlessPrimitive (unsafeAwaitSTM transaction :: Awaitable a)
 
 
-instance {-# OVERLAPS #-} (MonadCatch m, MonadFail m, MonadPlus m, MonadFix m) => MonadAwait (ReaderT (QueryFn m) m) where
+instance {-# OVERLAPS #-} (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait (ReaderT (QueryFn m) m) where
   await = runAwaitable
   unsafeAwaitSTM transaction = do
     QueryFn querySTMFn <- ask
-- 
GitLab