diff --git a/src/Quasar/Awaitable.hs b/src/Quasar/Awaitable.hs index 47ff5776c41d347499d1cbc4d6e0be8011032e3b..e11bcd0f478f0e4f4e31680055e062977ed9238d 100644 --- a/src/Quasar/Awaitable.hs +++ b/src/Quasar/Awaitable.hs @@ -5,12 +5,13 @@ module Quasar.Awaitable ( peekAwaitable, -- * Awaitable - IsAwaitable(..), + IsAwaitable(toAwaitable), Awaitable, successfulAwaitable, failedAwaitable, completedAwaitable, awaitableFromSTM, + cacheAwaitable, -- * Awaitable helpers afix, @@ -43,13 +44,11 @@ module Quasar.Awaitable ( failAsyncVarSTM_, putAsyncVarEitherSTM_, readAsyncVarSTM, - - -- * Implementation helpers - cacheAwaitableDefaultImplementation, ) where import Control.Applicative (empty) import Control.Concurrent.STM +import Control.Exception (BlockedIndefinitelyOnSTM(..)) import Control.Monad.Catch import Control.Monad.Reader import Control.Monad.Writer (WriterT) @@ -59,10 +58,11 @@ import Control.Monad.Trans.Maybe import Data.List.NonEmpty (NonEmpty(..), nonEmpty) import Data.Foldable (toList) import Data.Sequence +import GHC.IO (unsafeDupablePerformIO) import Quasar.Prelude -class (MonadCatch m, MonadFail m, MonadPlus m) => MonadAwait m where +class (MonadCatch m, MonadFail m, MonadPlus m, MonadFix m) => MonadAwait m where -- | Wait until an awaitable is completed and then return it's value (or throw an exception). await :: IsAwaitable r a => a -> m r @@ -113,28 +113,32 @@ peekAwaitable awaitable = liftIO $ runMaybeT $ runQueryT queryFn (runAwaitable a class IsAwaitable r a | a -> r where - -- | Run the awaitable. When interacting with an awaitable you usually want to use `await` instead. `runAwaitable` is exposed to manually implement an instance - -- of `IsAwaitable`. + toAwaitable :: a -> Awaitable r + toAwaitable = Awaitable + + -- | Run the awaitable. When interacting with an awaitable you usually want to use `await` instead. `runAwaitable` is + -- used to manually implement an instance of `IsAwaitable`. -- -- The implementation of `async` calls `runAwaitable` in most monads, so the implementation of `runAwaitable` must -- not call `async` without deconstructing first. runAwaitable :: (MonadAwait m) => a -> m r runAwaitable self = runAwaitable (toAwaitable self) - cacheAwaitable :: MonadIO m => a -> m (Awaitable r) - cacheAwaitable self = cacheAwaitable (toAwaitable self) + cacheAwaitableUnlessPrimitive :: MonadIO m => a -> m (Awaitable r) + cacheAwaitableUnlessPrimitive self = cacheAwaitableUnlessPrimitive (toAwaitable self) - toAwaitable :: a -> Awaitable r - toAwaitable = Awaitable + {-# MINIMAL toAwaitable | (runAwaitable, cacheAwaitableUnlessPrimitive) #-} - {-# MINIMAL toAwaitable | (runAwaitable, cacheAwaitable) #-} + +cacheAwaitable :: MonadIO m => Awaitable r -> m (Awaitable r) +cacheAwaitable = cacheAwaitableUnlessPrimitive data Awaitable r = forall a. IsAwaitable r a => Awaitable a instance IsAwaitable r (Awaitable r) where runAwaitable (Awaitable x) = runAwaitable x - cacheAwaitable (Awaitable x) = cacheAwaitable x + cacheAwaitableUnlessPrimitive (Awaitable x) = cacheAwaitableUnlessPrimitive x toAwaitable = id instance MonadAwait Awaitable where @@ -174,13 +178,16 @@ instance Alternative Awaitable where instance MonadPlus Awaitable +instance MonadFix Awaitable where + mfix fn = mkMonadicAwaitable $ mfix \x -> runAwaitable (fn x) + newtype MonadicAwaitable r = MonadicAwaitable (forall m. MonadAwait m => m r) instance IsAwaitable r (MonadicAwaitable r) where runAwaitable (MonadicAwaitable x) = x - cacheAwaitable = cacheAwaitableDefaultImplementation + cacheAwaitableUnlessPrimitive = cacheAwaitableDefaultImplementation mkMonadicAwaitable :: MonadAwait m => (forall f. (MonadAwait f) => f r) -> m r mkMonadicAwaitable fn = await $ MonadicAwaitable fn @@ -190,7 +197,7 @@ newtype CompletedAwaitable r = CompletedAwaitable (Either SomeException r) instance IsAwaitable r (CompletedAwaitable r) where runAwaitable (CompletedAwaitable x) = either throwM pure x - cacheAwaitable = pure . toAwaitable + cacheAwaitableUnlessPrimitive = pure . toAwaitable completedAwaitable :: Either SomeException r -> Awaitable r @@ -210,10 +217,10 @@ failedAwaitable = completedAwaitable . Left -- -- Use `retry` to signal that the awaitable is not yet completed and `throwM`/`throwSTM` to set the awaitable to failed. awaitableFromSTM :: forall m a. MonadIO m => STM a -> m (Awaitable a) -awaitableFromSTM transaction = cacheAwaitable (unsafeAwaitSTM transaction :: Awaitable a) +awaitableFromSTM transaction = cacheAwaitableUnlessPrimitive (unsafeAwaitSTM transaction :: Awaitable a) -instance {-# OVERLAPS #-} (MonadCatch m, MonadFail m, MonadPlus m) => MonadAwait (ReaderT (QueryFn m) m) where +instance {-# OVERLAPS #-} (MonadCatch m, MonadFail m, MonadPlus m, MonadFix m) => MonadAwait (ReaderT (QueryFn m) m) where await = runAwaitable unsafeAwaitSTM transaction = do QueryFn querySTMFn <- ask @@ -225,6 +232,7 @@ runQueryT :: forall m a. (forall b. STM b -> m b) -> ReaderT (QueryFn m) m a -> runQueryT queryFn action = runReaderT action (QueryFn queryFn) +-- TODO add guard to only allow one thread to step cache newtype CachedAwaitable r = CachedAwaitable (TVar (AwaitableStepM r)) cacheAwaitableDefaultImplementation :: (IsAwaitable r a, MonadIO m) => a -> m (Awaitable r) @@ -254,10 +262,11 @@ instance IsAwaitable r (CachedAwaitable r) where -- Query was successful. Update cache and exit query writeTVar tvar nextStep pure nextStep + -- Cache was already completed result -> pure result - cacheAwaitable = pure . toAwaitable + cacheAwaitableUnlessPrimitive = pure . toAwaitable data AwaitableStepM a = AwaitableCompleted a @@ -299,6 +308,23 @@ instance Alternative AwaitableStepM where instance MonadPlus AwaitableStepM where +instance MonadFix AwaitableStepM where + mfix :: forall a. (a -> AwaitableStepM a) -> AwaitableStepM a + mfix fn = AwaitableStep newEmptyTMVar applyFix + where + applyFix :: Either SomeException (TMVar a) -> AwaitableStepM a + applyFix (Right var) = do + result <- fn $ unsafeDupablePerformIO do + atomically (readTMVar var) `catch` \BlockedIndefinitelyOnSTM -> throwIO FixAwaitException + storeResult var result + applyFix (Left _) = impossibleCodePathM -- `newEmptyTMVar` should never fail + storeResult :: TMVar a -> a -> AwaitableStepM a + storeResult var x = AwaitableStep (void $ tryPutTMVar var x) (\_ -> pure x) + +data FixAwaitException = FixAwaitException + deriving stock (Eq, Show) + deriving anyclass Exception + -- ** AsyncVar @@ -308,7 +334,7 @@ newtype AsyncVar r = AsyncVar (TMVar (Either SomeException r)) instance IsAwaitable r (AsyncVar r) where runAwaitable (AsyncVar var) = unsafeAwaitSTM $ either throwM pure =<< readTMVar var -- An AsyncVar is a primitive awaitable, so caching is not necessary - cacheAwaitable = pure . toAwaitable + cacheAwaitableUnlessPrimitive = pure . toAwaitable newAsyncVarSTM :: STM (AsyncVar r) diff --git a/test/Quasar/AwaitableSpec.hs b/test/Quasar/AwaitableSpec.hs index 504564300c742114bd1ea45b32614105d2615769..90f3aef3b3b59d79bc8375c7994270c83ca22175 100644 --- a/test/Quasar/AwaitableSpec.hs +++ b/test/Quasar/AwaitableSpec.hs @@ -113,3 +113,19 @@ spec = parallel $ do await awaitable await awaitable readTVarIO var `shouldReturn` 1 + + it "can cache an mfix operation" $ io do + avar <- newAsyncVar + + r <- cacheAwaitable $ do + mfix \x -> do + v <- await avar + pure (v : x) + + peekAwaitable r `shouldReturn` Nothing + + putAsyncVar_ avar () + Just (():():():_) <- peekAwaitable r + + pure () +