From 9700867fe4b10132abc26a37b2deccc7a42dc486 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 3 Aug 2021 01:46:34 +0200
Subject: [PATCH] Add Monad instance to Awaitable with support for per-step
 caching

---
 src/Quasar/Awaitable.hs | 166 ++++++++++++++++++++++++++++------------
 1 file changed, 119 insertions(+), 47 deletions(-)

diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs
index 63d72e6..df22bf7 100644
--- a/src/Quasar/Awaitable.hs
+++ b/src/Quasar/Awaitable.hs
@@ -7,6 +7,7 @@ module Quasar.Awaitable (
   failedAwaitable,
   completedAwaitable,
   peekAwaitable,
+  awaitEither,
 
   -- * AsyncVar
   AsyncVar,
@@ -24,41 +25,55 @@ module Quasar.Awaitable (
 
 import Control.Concurrent.STM
 import Control.Monad.Catch
+import Control.Monad.Fix (mfix)
+import Control.Monad.Reader
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Maybe
 import Data.Bifunctor (bimap)
 import Quasar.Prelude
 
 
+
 class IsAwaitable r a | a -> r where
-  peekSTM :: a -> STM (Maybe (Either (Awaitable r) (Either SomeException r)))
-  peekSTM = peekSTM . toAwaitable
+  runAwaitable :: (Monad m) => a -> (forall b. STM (Maybe b) -> m b) -> m (Either SomeException r)
+  runAwaitable self = runAwaitable (toAwaitable self)
 
   toAwaitable :: a -> Awaitable r
-  toAwaitable x = Awaitable (peekSTM x)
+  toAwaitable x = Awaitable $ runAwaitable x
 
-  {-# MINIMAL toAwaitable | peekSTM #-}
+  {-# MINIMAL toAwaitable | runAwaitable #-}
 
 
 awaitIO :: (IsAwaitable r a, MonadIO m) => a -> m r
-awaitIO input = liftIO $ either throwIO pure =<< go (toAwaitable input)
-  where
-    go :: Awaitable r -> IO (Either SomeException r)
-    go x = do
-      stepResult <- atomically $ maybe retry pure =<< peekSTM x
-      either go pure stepResult
+awaitIO awaitable = liftIO $ either throwIO pure =<< runAwaitable awaitable (atomically . (maybe retry pure =<<))
+
+peekAwaitable :: (IsAwaitable r a, MonadIO m) => a -> m (Maybe (Either SomeException r))
+peekAwaitable awaitable = liftIO . runMaybeT $ runAwaitable awaitable (MaybeT . atomically)
 
 
-newtype Awaitable r = Awaitable (STM (Maybe (Either (Awaitable r) (Either SomeException r))))
+newtype Awaitable r = Awaitable (forall m. (Monad m) => (forall b. STM (Maybe b) -> m b) -> m (Either SomeException r))
 
 instance IsAwaitable r (Awaitable r) where
-  peekSTM (Awaitable x) = x
+  runAwaitable (Awaitable x) = x
   toAwaitable = id
 
 instance Functor Awaitable where
-  fmap fn = Awaitable . fmap (fmap (bimap (fmap fn) (fmap fn))) . peekSTM
+  fmap fn (Awaitable x) = Awaitable $ \querySTM -> fn <<$>> x querySTM
+
+instance Applicative Awaitable where
+  pure value = Awaitable $ \_ -> pure (Right value)
+  liftA2 fn (Awaitable fx) (Awaitable fy) = Awaitable $ \querySTM -> liftA2 (liftA2 fn) (fx querySTM) (fy querySTM)
+
+instance Monad Awaitable where
+  (Awaitable fx) >>= fn = Awaitable $ \querySTM -> do
+    fx querySTM >>= \case
+      Left ex -> pure $ Left ex
+      Right x -> runAwaitable (fn x) querySTM
+
 
 
 completedAwaitable :: Either SomeException r -> Awaitable r
-completedAwaitable = Awaitable . pure . Just . Right
+completedAwaitable result = Awaitable $ \_ -> pure result
 
 successfulAwaitable :: r -> Awaitable r
 successfulAwaitable = completedAwaitable . Right
@@ -66,29 +81,73 @@ successfulAwaitable = completedAwaitable . Right
 failedAwaitable :: SomeException -> Awaitable r
 failedAwaitable = completedAwaitable . Left
 
+simpleAwaitable :: STM (Maybe (Either SomeException a)) -> Awaitable a
+simpleAwaitable peekTransaction = Awaitable ($ peekTransaction)
 
-peekAwaitable :: (IsAwaitable r a, MonadIO m) => a -> m (Maybe (Either SomeException r))
-peekAwaitable input = liftIO $ go (toAwaitable input)
-  where
-    go :: Awaitable r -> IO (Maybe (Either SomeException r))
-    go x = atomically (peekSTM x) >>= \case
-      Nothing -> pure Nothing
-      Just (Right result) -> pure $ Just result
-      Just (Left step) -> go step
 
+class Monad m => MonadQuerySTM m where
+  querySTM :: (forall a. STM (Maybe a) -> m a)
+
+instance Monad m => MonadQuerySTM (ReaderT (QuerySTMFunction m) m) where
+  querySTM query = do
+    QuerySTMFunction querySTMFn <- ask
+    lift $ querySTMFn query
+
+data QuerySTMFunction m = QuerySTMFunction (forall b. STM (Maybe b) -> m b)
 
--- | Cache an `Awaitable`
---awaitableFromSTM :: STM (Maybe (Either SomeException r)) -> IO (Awaitable r)
---awaitableFromSTM fn = do
---  cache <- newTVarIO (Left fn)
---  pure . Awaitable $
---    readTVar cache >>= \case
---      Left generatorFn -> do
---        value <- generatorFn
---        writeTVar cache (Right value)
---        pure value
---      Right value -> pure value
 
+newtype CachedAwaitable r = CachedAwaitable (TVar (AwaitableStepM (Either SomeException r)))
+
+instance IsAwaitable r (CachedAwaitable r) where
+  runAwaitable :: forall m. (Monad m) => CachedAwaitable r -> (forall b. STM (Maybe b) -> m b) -> m (Either SomeException r)
+  runAwaitable (CachedAwaitable tvar) querySTM = go
+    where
+      go :: m (Either SomeException r)
+      go = querySTM stepCacheTransaction >>= \case
+        AwaitableCompleted result -> pure result
+        -- Cached operation is not yet completed
+        _ -> go
+
+      stepCacheTransaction :: STM (Maybe (AwaitableStepM (Either SomeException r)))
+      stepCacheTransaction = do
+        readTVar tvar >>= \case
+          -- Cache was already completed
+          result@(AwaitableCompleted _) -> pure $ Just result
+          AwaitableStep transaction fn -> do
+            -- Run the next "querySTM" transaction requested by the cached operation
+            fn <<$>> transaction >>= \case
+              -- In case of an incomplete transaction the caller (/ the monad `m`) can decide what to do (e.g. retry for `awaitIO`, abort for `peekAwaitable`)
+              Nothing -> pure Nothing
+              -- Query was successful. Update cache and exit transaction
+              Just nextStep -> do
+                writeTVar tvar nextStep
+                pure $ Just nextStep
+
+cacheAwaitable :: Awaitable a -> IO (CachedAwaitable a)
+cacheAwaitable awaitable = CachedAwaitable <$> newTVarIO (peekM awaitable peekStep)
+
+data AwaitableStepM a
+  = AwaitableCompleted a
+  | forall b. AwaitableStep (STM (Maybe b)) (b -> AwaitableStepM a)
+
+instance Functor AwaitableStepM where
+  fmap fn (AwaitableCompleted x) = AwaitableCompleted (fn x)
+  fmap fn (AwaitableStep transaction next) = AwaitableStep transaction (fmap fn <$> next)
+
+instance Applicative AwaitableStepM where
+  pure = AwaitableCompleted
+  liftA2 fn fx fy = fx >>= \x -> fn x <$> fy
+
+instance Monad AwaitableStepM where
+  (AwaitableCompleted x) >>= fn = fn x
+  (AwaitableStep transaction next) >>= fn = AwaitableStep transaction (next >=> fn)
+
+instance MonadQuerySTM AwaitableStepM where
+  querySTM transaction = AwaitableStep transaction AwaitableCompleted
+
+
+peekStep :: STM (Maybe a) -> AwaitableStepM a
+peekStep transaction = AwaitableStep transaction AwaitableCompleted
 
 
 -- ** AsyncVar
@@ -97,7 +156,7 @@ peekAwaitable input = liftIO $ go (toAwaitable input)
 newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r))
 
 instance IsAwaitable r (AsyncVar r) where
-  peekSTM (AsyncVar var) = fmap Right <$> tryReadTMVar var
+  runAwaitable (AsyncVar var) = ($ tryReadTMVar var)
 
 
 newAsyncVarSTM :: STM (AsyncVar r)
@@ -136,16 +195,29 @@ putAsyncVarEitherSTM_ var = void . putAsyncVarEitherSTM var
 
 -- * Awaiting multiple asyncs
 
--- TODO
---awaitEither :: (IsAwaitable ra a , IsAwaitable rb b, MonadIO m) => a -> b -> m (Awaitable (Either ra rb))
---awaitEither x y = liftIO $ awaitableFromSTM $ peekEitherSTM x y
---
---peekEitherSTM :: (IsAwaitable ra a , IsAwaitable rb b) => a -> b -> STM (Maybe (Either SomeException (Either ra rb)))
---peekEitherSTM x y =
---  peekSTM x >>= \case
---    Just (Left ex) -> pure (Just (Left ex))
---    Just (Right r) -> pure (Just (Right (Left r)))
---    Nothing -> peekSTM y >>= \case
---      Just (Left ex) -> pure (Just (Left ex))
---      Just (Right r) -> pure (Just (Right (Right r)))
---      Nothing -> pure Nothing
+awaitEither :: (IsAwaitable ra a , IsAwaitable rb b, MonadIO m) => a -> b -> m (Awaitable (Either ra rb))
+awaitEither x y = liftIO $ do
+  let startX = runAwaitable x peekStep
+  let startY = runAwaitable y peekStep
+  pure $ Awaitable $ \querySTM -> groupLefts <$> stepBoth startX startY querySTM
+  where
+    stepBoth :: Monad m => AwaitableStepM ra -> AwaitableStepM rb -> (forall c. STM (Maybe c) -> m c) -> m (Either ra rb)
+    stepBoth (AwaitableCompleted resultX) _ _ = pure $ Left resultX
+    stepBoth _ (AwaitableCompleted resultY) _ = pure $ Right resultY
+    stepBoth stepX@(AwaitableStep transactionX nextX) stepY@(AwaitableStep transactionY nextY) querySTM = do
+      querySTM (peekEitherSTM transactionX transactionY) >>= \case
+        Left resultX -> stepBoth (nextX resultX) stepY querySTM
+        Right resultY -> stepBoth stepX (nextY resultY) querySTM
+
+
+groupLefts :: Either (Either ex a) (Either ex b) -> Either ex (Either a b)
+groupLefts (Left x) = Left <$> x
+groupLefts (Right y) = Right <$> y
+
+peekEitherSTM :: STM (Maybe a) -> STM (Maybe b) -> STM (Maybe (Either a b))
+peekEitherSTM x y =
+  x >>= \case
+    Just r -> pure (Just (Left r))
+    Nothing -> y >>= \case
+      Just r -> pure (Just (Right r))
+      Nothing -> pure Nothing
-- 
GitLab