From 221d5c22aa358207463442e1ebbbdf429c25f56b Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Sun, 29 Aug 2021 00:36:32 +0200
Subject: [PATCH] Refine awaitable helper functions

---
 src/Quasar/Awaitable.hs  | 68 +++++++++++++++++++++++++++++++---------
 src/Quasar/Disposable.hs | 10 +++---
 2 files changed, 58 insertions(+), 20 deletions(-)

diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index c2aee3d..1c380cc 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -8,24 +8,37 @@ module Quasar.Awaitable (
   successfulAwaitable,
   failedAwaitable,
   completedAwaitable,
-  simpleAwaitable,
+  awaitSTM,
+  unsafeAwaitSTM,
 
-  -- * Awaiting multiple awaitables
+  -- * Awaitable helpers
+
+  awaitSuccessOrFailure,
+
+  -- ** Awaiting multiple awaitables
   awaitEither,
   awaitAny,
   awaitAny2,
 
   -- * AsyncVar
   AsyncVar,
+
+  -- ** Manage `AsyncVar`s in IO
   newAsyncVar,
-  newAsyncVarSTM,
   putAsyncVarEither,
-  putAsyncVarEitherSTM,
   putAsyncVar,
   putAsyncVar_,
   failAsyncVar,
   failAsyncVar_,
   putAsyncVarEither_,
+
+  -- ** Manage `AsyncVar`s in STM
+  newAsyncVarSTM,
+  putAsyncVarEitherSTM,
+  putAsyncVarSTM,
+  putAsyncVarSTM_,
+  failAsyncVarSTM,
+  failAsyncVarSTM_,
   putAsyncVarEitherSTM_,
 
   -- * Implementation helpers
@@ -138,12 +151,21 @@ successfulAwaitable = completedAwaitable . Right
 failedAwaitable :: SomeException -> Awaitable r
 failedAwaitable = completedAwaitable . Left
 
+-- | Create an awaitable from an `STM` transaction.
+--
+-- The first value or exception returned from the STM transaction will be cached and returned. The STM transacton
+-- should not have visible side effects.
+--
+-- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
+awaitSTM :: MonadIO m => STM a -> m (Awaitable a)
+awaitSTM = cacheAwaitable . unsafeAwaitSTM
+
 -- | Create an awaitable from an `STM` transaction. The STM transaction must always return the same result and should
 -- not have visible side effects.
 --
 -- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
-simpleAwaitable :: STM a -> Awaitable a
-simpleAwaitable query = fnAwaitable $ querySTM query
+unsafeAwaitSTM :: STM a -> Awaitable a
+unsafeAwaitSTM query = fnAwaitable $ querySTM query
 
 
 class MonadCatch m => MonadQuerySTM m where
@@ -151,11 +173,6 @@ class MonadCatch m => MonadQuerySTM m where
   querySTM :: (forall a. STM a -> m a)
 
 
--- | Run an `STM` transaction. Use `retry` to signal that no value is available (yet).
-tryQuerySTM :: MonadQuerySTM m => STM a -> m (Either SomeException a)
-tryQuerySTM transaction = querySTM (try transaction)
-
-
 instance MonadCatch m => MonadQuerySTM (ReaderT (QueryFn m) m) where
   querySTM query = do
     QueryFn querySTMFn <- ask
@@ -260,15 +277,27 @@ putAsyncVarEitherSTM (AsyncVar var) = tryPutTMVar var
 putAsyncVar :: MonadIO m => AsyncVar a -> a -> m Bool
 putAsyncVar var = putAsyncVarEither var . Right
 
+putAsyncVarSTM :: AsyncVar a -> a -> STM Bool
+putAsyncVarSTM var = putAsyncVarEitherSTM var . Right
+
 putAsyncVar_ :: MonadIO m => AsyncVar a -> a -> m ()
 putAsyncVar_ var = void . putAsyncVar var
 
-failAsyncVar :: MonadIO m => AsyncVar a -> SomeException -> m Bool
-failAsyncVar var = putAsyncVarEither var . Left
+putAsyncVarSTM_ :: AsyncVar a -> a -> STM ()
+putAsyncVarSTM_ var = void . putAsyncVarSTM var
+
+failAsyncVar :: (Exception e, MonadIO m) => AsyncVar a -> e -> m Bool
+failAsyncVar var = putAsyncVarEither var . Left . toException
 
-failAsyncVar_ :: MonadIO m => AsyncVar a -> SomeException -> m ()
+failAsyncVarSTM :: Exception e => AsyncVar a -> e -> STM Bool
+failAsyncVarSTM var = putAsyncVarEitherSTM var . Left . toException
+
+failAsyncVar_ :: (Exception e, MonadIO m) => AsyncVar a -> e -> m ()
 failAsyncVar_ var = void . failAsyncVar var
 
+failAsyncVarSTM_ :: Exception e => AsyncVar a -> e -> STM ()
+failAsyncVarSTM_ var = void . failAsyncVarSTM var
+
 putAsyncVarEither_ :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
 putAsyncVarEither_ var = void . putAsyncVarEither var
 
@@ -277,7 +306,16 @@ putAsyncVarEitherSTM_ var = void . putAsyncVarEitherSTM var
 
 
 
--- * Awaiting multiple asyncs
+-- * Utility functions
+
+-- | Create an awaitable that is completed successfully when the input awaitable is successful or failed.
+awaitSuccessOrFailure :: IsAwaitable r a => a -> Awaitable ()
+awaitSuccessOrFailure = fireAndForget . toAwaitable
+  where
+    fireAndForget :: MonadCatch m => m r -> m ()
+    fireAndForget x = void x `catchAll` const (pure ())
+
+-- ** Awaiting multiple awaitables
 
 awaitEither :: (IsAwaitable ra a, IsAwaitable rb b) => a -> b -> Awaitable (Either ra rb)
 awaitEither x y = fnAwaitable $ stepBoth (runAwaitable x) (runAwaitable y)
diff --git a/src/Quasar/Disposable.hs b/src/Quasar/Disposable.hs
index 7378ee5..bccd3b3 100644
--- a/src/Quasar/Disposable.hs
+++ b/src/Quasar/Disposable.hs
@@ -100,7 +100,7 @@ instance IsDisposable FnDisposable where
 instance IsAwaitable () FnDisposable where
   toAwaitable :: FnDisposable -> Awaitable ()
   toAwaitable (FnDisposable var) =
-    join $ simpleAwaitable do
+    join $ unsafeAwaitSTM do
       state <- readTMVar var
       case state of
         -- Wait until disposing has been started
@@ -164,7 +164,7 @@ newtype ResourceManagerEntry = ResourceManagerEntry (TMVar (Awaitable (), Dispos
 
 instance IsAwaitable () ResourceManagerEntry where
   toAwaitable (ResourceManagerEntry var) = do
-    varContents <- simpleAwaitable $ tryReadTMVar var
+    varContents <- unsafeAwaitSTM $ tryReadTMVar var
     case varContents of
       -- If the var is empty the Entry has already been disposed
       Nothing -> pure ()
@@ -222,7 +222,7 @@ instance IsDisposable ResourceManager where
         pure $ isDisposed resourceManager
 
   isDisposed resourceManager =
-    simpleAwaitable do
+    unsafeAwaitSTM do
       (throwM =<< readTMVar (exceptionVar resourceManager))
         `orElse`
           ((\disposed -> unless disposed retry) =<< readTVar (disposedVar resourceManager))
@@ -263,11 +263,11 @@ collectGarbage resourceManager = go
     go = do
       snapshot <- atomically $ readTVar entriesVar'
 
-      let listChanged = simpleAwaitable do
+      let listChanged = unsafeAwaitSTM do
             newLength <- Seq.length <$> readTVar entriesVar'
             when (newLength == Seq.length snapshot) retry
 
-          isDisposing = simpleAwaitable do
+          isDisposing = unsafeAwaitSTM do
             disposing <- readTVar (disposingVar resourceManager)
             unless disposing retry
 
-- 
GitLab