From f0fd12ae0d735437349555c660ba25b6837cc52b Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Sat, 18 Dec 2021 22:42:38 +0100
Subject: [PATCH] Remove awaitable caching

This reduces awaitable back to a simple newtype wrapper over the STM
monad. While caching worked well, it was only applicable in very few
situations. Based on the high complexity, the rare use cases and the
negative interaction with STM we decided to remove caching.

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 src/Quasar/Awaitable.hs       | 287 +++++-----------------------------
 src/Quasar/ResourceManager.hs |   3 +-
 test/Quasar/AwaitableSpec.hs  |  67 --------
 3 files changed, 39 insertions(+), 318 deletions(-)

diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index ff7ce9c..14992a8 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -1,7 +1,6 @@
 module Quasar.Awaitable (
   -- * MonadAwaitable
   MonadAwait(..),
-  awaitResult,
   peekAwaitable,
 
   -- * Awaitable
@@ -10,8 +9,6 @@ module Quasar.Awaitable (
   successfulAwaitable,
   failedAwaitable,
   completedAwaitable,
-  awaitableFromSTM,
-  cacheAwaitable,
 
   -- * Awaitable helpers
   afix,
@@ -45,9 +42,12 @@ module Quasar.Awaitable (
   putAsyncVarEitherSTM_,
   readAsyncVarSTM,
   tryReadAsyncVarSTM,
+
+  -- ** Unsafe implementation helpers
+  unsafeSTMToAwaitable,
+  unsafeAwaitSTM,
 ) where
 
-import Control.Applicative (empty)
 import Control.Concurrent.STM
 import Control.Exception (BlockedIndefinitelyOnSTM(..))
 import Control.Monad.Catch
@@ -56,10 +56,6 @@ import Control.Monad.Writer (WriterT)
 import Control.Monad.State (StateT)
 import Control.Monad.RWS (RWST)
 import Control.Monad.Trans.Maybe
-import Data.List.NonEmpty (NonEmpty(..), nonEmpty)
-import Data.Foldable (toList)
-import Data.Sequence
-import GHC.IO (unsafeDupablePerformIO)
 import Quasar.Prelude
 
 
@@ -67,13 +63,6 @@ class (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait m where
   -- | Wait until an awaitable is completed and then return it's value (or throw an exception).
   await :: IsAwaitable r a => a -> m r
 
-  -- | Await an `STM` transaction. The STM transaction must always return the same result and should not have visible
-  -- side effects.
-  --
-  -- Use `retry` to signal that the awaitable is not yet completed and use `throwM`/`throwSTM` to signal a failed
-  -- awaitable.
-  unsafeAwaitSTM :: STM a -> m a
-
 data BlockedIndefinitelyOnAwait = BlockedIndefinitelyOnAwait
   deriving stock Show
 
@@ -82,98 +71,71 @@ instance Exception BlockedIndefinitelyOnAwait where
 
 
 instance MonadAwait IO where
-  await awaitable = liftIO do
-    runQueryT atomically (runAwaitable awaitable)
+  await (toAwaitable -> Awaitable x) =
+    atomically x
       `catch`
         \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
-  unsafeAwaitSTM = atomically
 
--- | Experimental instance for `STM`. Using `await` in STM circumvents awaitable caching mechanics, so this only
--- exists as a test to estimate the usefulness of caching awaitables against the usefulness of awaiting in STM.
-instance MonadAwait STM where
-  await awaitable =
-    runQueryT id (runAwaitable awaitable)
-      `catch`
-        \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
-  unsafeAwaitSTM = id
+awaitSTM :: Awaitable a -> STM a
+awaitSTM (toAwaitable -> Awaitable x) =
+  x `catch` \BlockedIndefinitelyOnSTM -> throwM BlockedIndefinitelyOnAwait
 
 instance MonadAwait m => MonadAwait (ReaderT a m) where
   await = lift . await
-  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 instance (MonadAwait m, Monoid a) => MonadAwait (WriterT a m) where
   await = lift . await
-  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 instance MonadAwait m => MonadAwait (StateT a m) where
   await = lift . await
-  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 instance (MonadAwait m, Monoid w) => MonadAwait (RWST r w s m) where
   await = lift . await
-  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 instance MonadAwait m => MonadAwait (MaybeT m) where
   await = lift . await
-  unsafeAwaitSTM = lift . unsafeAwaitSTM
 
 
-awaitResult :: (IsAwaitable r a, MonadAwait m) => m a -> m r
-awaitResult = (await =<<)
-
 
 -- | Returns the result (in a `Just`) when the awaitable is completed, throws an `Exception` when the awaitable is
 -- failed and returns `Nothing` otherwise.
-peekAwaitable :: (IsAwaitable r a, MonadIO m) => a -> m (Maybe r)
-peekAwaitable awaitable = liftIO $ runMaybeT $ runQueryT queryFn (runAwaitable awaitable)
-  where
-    queryFn :: STM a -> MaybeT IO a
-    queryFn transaction = MaybeT $ atomically $ (Just <$> transaction) `orElse` pure Nothing
+peekAwaitable :: MonadIO m => Awaitable r -> m (Maybe r)
+peekAwaitable awaitable = liftIO $ atomically $ (Just <$> awaitSTM awaitable) `orElse` pure Nothing
 
 
 
 class IsAwaitable r a | a -> r where
   toAwaitable :: a -> Awaitable r
-  toAwaitable = Awaitable
 
-  -- | Run the awaitable. When interacting with an awaitable you usually want to use `await` instead. `runAwaitable` is
-  -- used to manually implement an instance of `IsAwaitable`.
-  --
-  -- The implementation of `async` calls `runAwaitable` in most monads, so the implementation of `runAwaitable` must
-  -- not call `async` without deconstructing first.
-  runAwaitable :: (MonadAwait m) => a -> m r
-  runAwaitable self = runAwaitable (toAwaitable self)
 
-  cacheAwaitableUnlessPrimitive :: MonadIO m => a -> m (Awaitable r)
-  cacheAwaitableUnlessPrimitive self = cacheAwaitableUnlessPrimitive (toAwaitable self)
 
-  {-# MINIMAL toAwaitable | (runAwaitable, cacheAwaitableUnlessPrimitive) #-}
 
 
-cacheAwaitable :: MonadIO m => Awaitable r -> m (Awaitable r)
-cacheAwaitable = cacheAwaitableUnlessPrimitive
+unsafeSTMToAwaitable :: STM a -> Awaitable a
+unsafeSTMToAwaitable = Awaitable
+
+unsafeAwaitSTM :: MonadAwait m => STM a -> m a
+unsafeAwaitSTM = await . unsafeSTMToAwaitable
+
 
+newtype Awaitable r = Awaitable (STM r)
+  deriving newtype (
+    Functor,
+    Applicative,
+    Monad,
+    MonadThrow,
+    MonadCatch,
+    MonadFix,
+    Alternative,
+    MonadPlus
+    )
 
-data Awaitable r = forall a. IsAwaitable r a => Awaitable a
 
 instance IsAwaitable r (Awaitable r) where
-  runAwaitable (Awaitable x) = runAwaitable x
-  cacheAwaitableUnlessPrimitive (Awaitable x) = cacheAwaitableUnlessPrimitive x
   toAwaitable = id
 
 instance MonadAwait Awaitable where
   await = toAwaitable
-  unsafeAwaitSTM transaction = mkMonadicAwaitable $ unsafeAwaitSTM transaction
-
-instance Functor Awaitable where
-  fmap fn (Awaitable x) = mkMonadicAwaitable $ fn <$> runAwaitable x
-
-instance Applicative Awaitable where
-  pure = successfulAwaitable
-  liftA2 fn (Awaitable fx) (Awaitable fy) = mkMonadicAwaitable $ liftA2 fn (runAwaitable fx) (runAwaitable fy)
-
-instance Monad Awaitable where
-  (Awaitable fx) >>= fn = mkMonadicAwaitable $ runAwaitable fx >>= runAwaitable . fn
 
 instance Semigroup r => Semigroup (Awaitable r) where
   x <> y = liftA2 (<>) x y
@@ -181,169 +143,22 @@ instance Semigroup r => Semigroup (Awaitable r) where
 instance Monoid r => Monoid (Awaitable r) where
   mempty = pure mempty
 
-instance MonadThrow Awaitable where
-  throwM = failedAwaitable . toException
-
-instance MonadCatch Awaitable where
-  catch awaitable handler = mkMonadicAwaitable do
-    runAwaitable awaitable `catch` \ex -> runAwaitable (handler ex)
-
 instance MonadFail Awaitable where
   fail = throwM . userError
 
 
-instance Alternative Awaitable where
-  x <|> y = x `catchAll` const y
-  empty = failedAwaitable $ toException $ userError "empty"
-
-instance MonadPlus Awaitable
-
-instance MonadFix Awaitable where
-  mfix fn = mkMonadicAwaitable $ mfix \x -> runAwaitable (fn x)
-
-
-
-newtype MonadicAwaitable r = MonadicAwaitable (forall m. MonadAwait m => m r)
-
-instance IsAwaitable r (MonadicAwaitable r) where
-  runAwaitable (MonadicAwaitable x) = x
-  cacheAwaitableUnlessPrimitive = cacheAwaitableDefaultImplementation
-
-mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadAwait f) => f r) -> m r
-mkMonadicAwaitable fn = await $ MonadicAwaitable fn
-
-
-newtype CompletedAwaitable r = CompletedAwaitable (Either SomeException r)
-
-instance IsAwaitable r (CompletedAwaitable r) where
-  runAwaitable (CompletedAwaitable x) = either throwM pure x
-  cacheAwaitableUnlessPrimitive = pure . toAwaitable
 
 
 completedAwaitable :: Either SomeException r -> Awaitable r
-completedAwaitable result = toAwaitable $ CompletedAwaitable result
+completedAwaitable = either throwM pure
 
 -- | Alias for `pure`.
 successfulAwaitable :: r -> Awaitable r
-successfulAwaitable = completedAwaitable . Right
+successfulAwaitable = pure
 
 failedAwaitable :: SomeException -> Awaitable r
-failedAwaitable = completedAwaitable . Left
-
--- | Create an awaitable from an `STM` transaction.
---
--- The first value or exception returned from the STM transaction will be cached and returned. The STM transacton
--- should not have visible side effects.
---
--- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed.
-awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a)
-awaitableFromSTM transaction = cacheAwaitableUnlessPrimitive (unsafeAwaitSTM transaction :: Awaitable a)
-
-
-instance {-# OVERLAPS #-} (MonadCatch m, MonadPlus m, MonadFix m) => MonadAwait (ReaderT (QueryFn m) m) where
-  await = runAwaitable
-  unsafeAwaitSTM transaction = do
-    QueryFn querySTMFn <- ask
-    lift $ querySTMFn transaction
-
-newtype QueryFn m = QueryFn (forall a. STM a -> m a)
-
-runQueryT :: forall m a. (forall b. STM b -> m b) -> ReaderT (QueryFn m) m a -> m a
-runQueryT queryFn action = runReaderT action (QueryFn queryFn)
-
-
--- TODO add guard to only allow one thread to step cache
-newtype CachedAwaitable r = CachedAwaitable (TVar (AwaitableStepM r))
-
-cacheAwaitableDefaultImplementation :: (IsAwaitable r a, MonadIO m) => a -> m (Awaitable r)
-cacheAwaitableDefaultImplementation awaitable = toAwaitable . CachedAwaitable <$> liftIO (newTVarIO (runAwaitable awaitable))
-
-instance IsAwaitable r (CachedAwaitable r) where
-  runAwaitable :: forall m. MonadAwait m => CachedAwaitable r -> m r
-  runAwaitable (CachedAwaitable tvar) = go
-    where
-      go :: m r
-      go = unsafeAwaitSTM stepCacheTransaction >>= \case
-        AwaitableCompleted result -> pure result
-        AwaitableFailed ex -> throwM ex
-        -- Cached operation is not yet completed
-        _ -> go
-
-      stepCacheTransaction :: STM (AwaitableStepM r)
-      stepCacheTransaction = do
-        readTVar tvar >>= \case
-          -- Cache needs to be stepped
-          AwaitableStep query fn -> do
-            -- Run the next "querySTM" query requested by the cached operation
-            -- The query might `retry`, which is ok here
-            nextStep <- fn <$> try query
-            -- In case of an incomplete query the caller (/ the monad `m`) can decide what to do (e.g. retry for
-            -- `awaitIO`, abort for `peekAwaitable`)
-            -- Query was successful. Update cache and exit query
-            writeTVar tvar nextStep
-            pure nextStep
-
-          -- Cache was already completed
-          result -> pure result
-
-  cacheAwaitableUnlessPrimitive = pure . toAwaitable
-
-data AwaitableStepM a
-  = AwaitableCompleted a
-  | AwaitableFailed SomeException
-  | forall b. AwaitableStep (STM b) (Either SomeException b -> AwaitableStepM a)
-
-instance Functor AwaitableStepM where
-  fmap fn (AwaitableCompleted x) = AwaitableCompleted (fn x)
-  fmap _ (AwaitableFailed ex) = AwaitableFailed ex
-  fmap fn (AwaitableStep query next) = AwaitableStep query (fmap fn <$> next)
-
-instance Applicative AwaitableStepM where
-  pure = AwaitableCompleted
-  liftA2 fn fx fy = fx >>= \x -> fn x <$> fy
-
-instance Monad AwaitableStepM where
-  (AwaitableCompleted x) >>= fn = fn x
-  (AwaitableFailed ex) >>= _ = AwaitableFailed ex
-  (AwaitableStep query next) >>= fn = AwaitableStep query (next >=> fn)
-
-instance MonadAwait AwaitableStepM where
-  await = runAwaitable
-  unsafeAwaitSTM query = AwaitableStep query (either AwaitableFailed AwaitableCompleted)
-
-instance MonadThrow AwaitableStepM where
-  throwM = AwaitableFailed . toException
-
-instance MonadCatch AwaitableStepM where
-  catch result@(AwaitableCompleted _) _ = result
-  catch result@(AwaitableFailed ex) handler = maybe result handler $ fromException ex
-  catch (AwaitableStep query next) handler = AwaitableStep query (\x -> next x `catch` handler)
-
-instance MonadFail AwaitableStepM where
-  fail = throwM . userError
-
-instance Alternative AwaitableStepM where
-  x <|> y = x `catchAll` const y
-  empty = throwM $ toException $ userError "empty"
-
-instance MonadPlus AwaitableStepM
+failedAwaitable = throwM
 
-instance MonadFix AwaitableStepM where
-  mfix :: forall a. (a -> AwaitableStepM a) -> AwaitableStepM a
-  mfix fn = AwaitableStep newEmptyTMVar applyFix
-    where
-      applyFix :: Either SomeException (TMVar a) -> AwaitableStepM a
-      applyFix (Right var) = do
-        result <- fn $ unsafeDupablePerformIO do
-          atomically (readTMVar var) `catch` \BlockedIndefinitelyOnSTM -> throwIO FixAwaitException
-        storeResult var result
-      applyFix (Left _) = unreachableCodePathM -- `newEmptyTMVar` should never fail
-      storeResult :: TMVar a -> a -> AwaitableStepM a
-      storeResult var x = AwaitableStep (void $ tryPutTMVar var x) (\_ -> pure x)
-
-data FixAwaitException = FixAwaitException
-  deriving stock (Eq, Show)
-  deriving anyclass Exception
 
 
 -- ** AsyncVar
@@ -352,9 +167,7 @@ data FixAwaitException = FixAwaitException
 newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r))
 
 instance IsAwaitable r (AsyncVar r) where
-  runAwaitable (AsyncVar var) = unsafeAwaitSTM $ either throwM pure =<< readTMVar var
-  -- An AsyncVar is a primitive awaitable, so caching is not necessary
-  cacheAwaitableUnlessPrimitive = pure . toAwaitable
+  toAwaitable (AsyncVar var) = unsafeSTMToAwaitable $ either throwM pure =<< readTMVar var
 
 
 newAsyncVarSTM :: STM (AsyncVar r)
@@ -441,17 +254,7 @@ afix_ = void . afix
 
 -- | Completes as soon as either awaitable completes.
 awaitEither :: MonadAwait m => Awaitable ra -> Awaitable rb -> m (Either ra rb)
-awaitEither x y = mkMonadicAwaitable $ stepBoth (runAwaitable x) (runAwaitable y)
-  where
-    stepBoth :: MonadAwait m => AwaitableStepM ra -> AwaitableStepM rb -> m (Either ra rb)
-    stepBoth (AwaitableCompleted resultX) _ = pure $ Left resultX
-    stepBoth (AwaitableFailed ex) _ = throwM ex
-    stepBoth _ (AwaitableCompleted resultY) = pure $ Right resultY
-    stepBoth _ (AwaitableFailed ex) = throwM ex
-    stepBoth stepX@(AwaitableStep transactionX nextX) stepY@(AwaitableStep transactionY nextY) = do
-      unsafeAwaitSTM (eitherSTM (try transactionX) (try transactionY)) >>= \case
-        Left resultX -> stepBoth (nextX resultX) stepY
-        Right resultY -> stepBoth stepX (nextY resultY)
+awaitEither (Awaitable x) (Awaitable y) = unsafeAwaitSTM (eitherSTM x y)
 
 -- | Helper for `awaitEither`
 eitherSTM :: STM a -> STM b -> STM (Either a b)
@@ -461,28 +264,12 @@ eitherSTM x y = fmap Left x `orElse` fmap Right y
 -- Completes as soon as any awaitable in the list is completed and then returns the left-most completed result
 -- (or exception).
 awaitAny :: MonadAwait m => [Awaitable r] -> m r
-awaitAny xs = mkMonadicAwaitable $ stepAll Empty Empty $ runAwaitable <$> fromList xs
-  where
-    stepAll
-      :: MonadAwait m
-      => Seq (STM (Seq (AwaitableStepM r)))
-      -> Seq (AwaitableStepM r)
-      -> Seq (AwaitableStepM r)
-      -> m r
-    stepAll _ _ (AwaitableCompleted result :<| _) = pure result
-    stepAll _ _ (AwaitableFailed ex :<| _) = throwM ex
-    stepAll acc prevSteps (step@(AwaitableStep transaction next) :<| steps) =
-      stepAll
-        do acc |> ((\result -> (prevSteps |> next result) <> steps) <$> try transaction)
-        do prevSteps |> step
-        steps
-    stepAll acc _ Empty = do
-      newAwaitableSteps <- unsafeAwaitSTM $ maybe unreachableCodePathM anySTM $ nonEmpty (toList acc)
-      stepAll Empty Empty newAwaitableSteps
+awaitAny xs = unsafeAwaitSTM $ anySTM $ awaitSTM <$> xs
 
 -- | Helper for `awaitAny`
-anySTM :: NonEmpty (STM a) -> STM a
-anySTM (x :| xs) = x `orElse` maybe retry anySTM (nonEmpty xs)
+anySTM :: [STM a] -> STM a
+anySTM [] = retry
+anySTM (x:xs) = x `orElse` anySTM xs
 
 
 -- | Like `awaitAny` with two awaitables.
diff --git a/src/Quasar/ResourceManager.hs b/src/Quasar/ResourceManager.hs
index e04d66f..be1a028 100644
--- a/src/Quasar/ResourceManager.hs
+++ b/src/Quasar/ResourceManager.hs
@@ -384,7 +384,8 @@ instance IsDisposable DefaultResourceManager where
       primaryBeginDispose :: [Disposable] -> IO DisposeResult
       primaryBeginDispose disposables = do
         (reportExceptionActions, resultAwaitables) <- unzip <$> mapM beginDisposeEntry disposables
-        cachedResultAwaitable <- cacheAwaitable $ mconcat resultAwaitables
+        -- TODO caching was removed; re-optimize later
+        let cachedResultAwaitable = mconcat resultAwaitables
         putAsyncVar_ resultVar cachedResultAwaitable
 
         let
diff --git a/test/Quasar/AwaitableSpec.hs b/test/Quasar/AwaitableSpec.hs
index 10f8823..4f88899 100644
--- a/test/Quasar/AwaitableSpec.hs
+++ b/test/Quasar/AwaitableSpec.hs
@@ -62,70 +62,3 @@ spec = parallel $ do
         threadDelay 100000
         putAsyncVar_ avar2 ()
       awaitAny2 (await avar1) (await avar2)
-
-  describe "cacheAwaitable" do
-    it "can cache an awaitable" $ io do
-      var <- newTVarIO (0 :: Int)
-      awaitable <- cacheAwaitable do
-        unsafeAwaitSTM (modifyTVar var (+ 1)) :: Awaitable ()
-      await awaitable
-      await awaitable
-      readTVarIO var `shouldReturn` 1
-
-    it "can cache a bind" $ io do
-      var1 <- newTVarIO (0 :: Int)
-      var2 <- newTVarIO (0 :: Int)
-      awaitable <- cacheAwaitable do
-        unsafeAwaitSTM (modifyTVar var1 (+ 1)) >>= \_ -> unsafeAwaitSTM (modifyTVar var2 (+ 1)) :: Awaitable ()
-      await awaitable
-      await awaitable
-      readTVarIO var1 `shouldReturn` 1
-      readTVarIO var2 `shouldReturn` 1
-
-    it "can cache an exception" $ io do
-      var <- newMVar (0 :: Int)
-      awaitable <- cacheAwaitable do
-        unsafeAwaitSTM (unsafeIOToSTM (modifyMVar_ var (pure . (+ 1))) >> throwM TestException) :: Awaitable ()
-      await awaitable `shouldThrow` \TestException -> True
-      await awaitable `shouldThrow` \TestException -> True
-      readMVar var `shouldReturn` 1
-
-    it "can cache the left side of an awaitAny2" $ io do
-      var <- newTVarIO (0 :: Int)
-
-      let a1 = unsafeAwaitSTM (modifyTVar var (+ 1)) :: Awaitable ()
-      let a2 = unsafeAwaitSTM retry :: Awaitable ()
-
-      awaitable <- cacheAwaitable $ (awaitAny2 a1 a2 :: Awaitable ())
-
-      await awaitable
-      await awaitable
-      readTVarIO var `shouldReturn` 1
-
-    it "can cache the right side of an awaitAny2" $ io do
-      var <- newTVarIO (0 :: Int)
-
-      let a1 = unsafeAwaitSTM retry :: Awaitable ()
-      let a2 = unsafeAwaitSTM (modifyTVar var (+ 1)) :: Awaitable ()
-
-      awaitable <- cacheAwaitable $ (awaitAny2 a1 a2 :: Awaitable ())
-
-      await awaitable
-      await awaitable
-      readTVarIO var `shouldReturn` 1
-
-    it "can cache an mfix operation" $ io do
-      avar <- newAsyncVar
-
-      r <- cacheAwaitable $ do
-        mfix \x -> do
-          v <- await avar
-          pure (v : x)
-
-      peekAwaitable r `shouldReturn` Nothing
-
-      putAsyncVar_ avar ()
-      Just (():():():_) <- peekAwaitable r
-
-      pure ()
-
-- 
GitLab