From a9c2fe5b7ae7d95d2f587513d4e05eab02bc3739 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 3 May 2022 16:51:34 +0200
Subject: [PATCH] Improve observable (use TDisposer, explicitly track observer
 resources)

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 src/Quasar/Observable.hs | 303 ++++++++++++++++++---------------------
 1 file changed, 138 insertions(+), 165 deletions(-)

diff --git a/src/Quasar/Observable.hs b/src/Quasar/Observable.hs
index 0faa166..051b8e9 100644
--- a/src/Quasar/Observable.hs
+++ b/src/Quasar/Observable.hs
@@ -4,14 +4,15 @@ module Quasar.Observable (
   ObservableState(..),
   IsRetrievable(..),
   IsObservable(..),
+  observe,
   observe_,
   observeIO,
   observeIO_,
 
   -- ** Control flow utilities
+  observeWith,
   observeBlocking,
-  observeUntil,
-  observeUntil_,
+  observeAsync,
 
   -- * ObservableVar
   ObservableVar,
@@ -25,7 +26,11 @@ module Quasar.Observable (
   -- * Helpers
 
   -- ** Helper types
-  ObservableCallback,
+  ObserverContext,
+  ObserverCallback,
+  observableCallback,
+  execObservableCallback,
+  mapObservableCallback,
 
   -- ** Observable implementation primitive
   ObservablePrim,
@@ -42,11 +47,12 @@ import Control.Monad.Catch
 import Control.Monad.Except
 import Control.Monad.Trans.Maybe
 import Data.HashMap.Strict qualified as HM
-import Data.IORef
 import Data.Unique
+import Quasar.Async
+import Quasar.Async.STMHelper
+import Quasar.Exceptions
 import Quasar.Prelude
 import Quasar.MonadQuasar
-import Quasar.MonadQuasar.Misc
 import Quasar.Resources
 
 data ObservableState a
@@ -81,18 +87,14 @@ class IsObservable r a | a -> r where
   -- | Register a callback to observe changes. The callback is called when the value changes, but depending on the
   -- delivery method (e.g. network) intermediate values may be skipped.
   --
-  -- A correct implementation of observe must call the callback during registration (if no value is available
+  -- A correct implementation of `attachObserver` must call the callback during registration (if no value is available
   -- immediately an `ObservableLoading` will be delivered).
   --
   -- The callback should return without blocking, otherwise other callbacks will be delayed. If the value can't be
   -- processed immediately, use `observeBlocking` instead or manually pass the value to a thread that processes the
   -- data.
-  observe
-    :: (MonadQuasar m, MonadSTM m)
-    => a -- ^ observable
-    -> ObservableCallback r -- ^ callback
-    -> m Disposer
-  observe observable = observe (toObservable observable)
+  attachObserver :: a -> ObserverCallback r -> ObserverContext -> STM TDisposer
+  attachObserver observable = attachObserver (toObservable observable)
 
   toObservable :: a -> Observable r
   toObservable = Observable
@@ -100,38 +102,68 @@ class IsObservable r a | a -> r where
   mapObservable :: (r -> r2) -> a -> Observable r2
   mapObservable f = Observable . MappedObservable f . toObservable
 
-  {-# MINIMAL toObservable | observe #-}
+  {-# MINIMAL toObservable | attachObserver #-}
+
 
+observe
+  :: (MonadQuasar m, MonadSTM m)
+  => Observable a
+  -> (ObservableState a -> STM ()) -- ^ callback
+  -> m Disposer
+observe observable callbackFn = do
+  -- Each observer needs a dedicated scope to guarantee, that the whole observer is detached when the provided callback (or the observable implementation) fails.
+  scope <- newResourceScope
+  liftSTM do
+    let ctx = ObserverContext (quasarIOWorker scope) (quasarExceptionSink scope) (quasarResourceManager scope)
+    disposer <- attachObserver observable (observableCallback callbackFn) ctx
+    attachResource (quasarResourceManager scope) disposer
+    pure $ toDisposer (quasarResourceManager scope)
 
 observe_
-    :: (IsObservable r a, MonadQuasar m, MonadSTM m)
-    => a -- ^ observable
-    -> ObservableCallback r -- ^ callback
+    :: (MonadQuasar m, MonadSTM m)
+    => Observable a
+    -> (ObservableState a -> STM ()) -- ^ callback
     -> m ()
 observe_ observable callback = liftQuasarSTM $ void $ observe observable callback
 
 observeIO
-  :: (IsObservable r a, MonadQuasar m, MonadIO m)
-  => a -- ^ observable
-  -> ObservableCallback r -- ^ callback
+  :: (MonadQuasar m, MonadIO m)
+  => Observable a
+  -> (ObservableState a -> STM ()) -- ^ callback
   -> m Disposer
 observeIO observable callback = quasarAtomically $ observe observable callback
 
 observeIO_
-  :: (IsObservable r a, MonadQuasar m, MonadIO m)
-  => a -- ^ observable
-  -> ObservableCallback r -- ^ callback
+  :: (MonadQuasar m, MonadIO m)
+  => Observable a
+  -> (ObservableState a -> STM ()) -- ^ callback
   -> m ()
 observeIO_ observable callback = quasarAtomically $ observe_ observable callback
 
 
-type ObservableCallback a = ObservableState a -> QuasarSTM ()
+
+-- | Context for an observer, generated by calling `observe`.
+data ObserverContext = ObserverContext TIOWorker ExceptionSink ResourceManager
+
+-- | Callback wrapper, internally used for `attachObserver`. Using the newtype prevents the callback from being called without appropriate error handling. This ensures exceptions from callbacks do not interfere with observable updates.
+newtype ObserverCallback a = ObserverCallback (ObservableState a -> STM ())
+
+observableCallback :: (ObservableState a -> STM ()) -> ObserverCallback a
+observableCallback = ObserverCallback
+
+execObservableCallback :: ObserverCallback a -> ObserverContext -> ObservableState a -> STM ()
+execObservableCallback (ObserverCallback fn) (ObserverContext _ sink _) arg =
+  fn arg `catchAll` \ex -> throwToExceptionSink sink ex
+
+mapObservableCallback :: (ObservableState b -> ObservableState a) -> ObserverCallback a -> ObserverCallback b
+mapObservableCallback fn (ObserverCallback cb) = ObserverCallback (cb . fn)
+
 
 
 -- | Existential quantification wrapper for the IsObservable type class.
 data Observable r = forall a. IsObservable r a => Observable a
 instance IsObservable r (Observable r) where
-  observe (Observable o) = observe o
+  attachObserver (Observable o) = attachObserver o
   toObservable = id
   mapObservable f (Observable o) = mapObservable f o
 
@@ -139,18 +171,18 @@ instance Functor Observable where
   fmap f = mapObservable f
 
 instance Applicative Observable where
-  pure = toObservable . ConstObservable
+  pure value = toObservable (ConstObservable (ObservableValue value))
   liftA2 fn x y = toObservable $ LiftA2Observable fn x y
 
 instance Monad Observable where
-  x >>= f = toObservable $ BindObservable x f
+  x >>= f = bindObservable x f
 
 instance MonadThrow Observable where
   throwM :: forall e v. Exception e => e -> Observable v
-  throwM = toObservable . ThrowObservable @v . toException
+  throwM ex = toObservable (ConstObservable (ObservableNotAvailable (toException ex)))
 
 instance MonadCatch Observable where
-  catch action handler = toObservable $ CatchObservable action handler
+  catch action handler = catchObservable action handler
 
 instance MonadFail Observable where
   fail = throwM . userError
@@ -176,84 +208,54 @@ instance Monoid a => Monoid (Observable a) where
 -- The handler is allowed to block. When the value changes while the handler is running the handler will be run again
 -- after it completes; when the value changes multiple times it will only be executed once (with the latest value).
 observeBlocking
-  :: (IsObservable r a, MonadQuasar m, MonadIO m, MonadMask m)
-  => a
+  :: (MonadQuasar m, MonadIO m, MonadMask m)
+  => Observable r
   -> (ObservableState r -> m ())
-  -> m b
+  -> m a
 observeBlocking observable handler = do
+  observeWith observable \fetchNext -> forever do
+    msg <- atomically $ fetchNext
+    handler msg
+
+observeAsync
+  :: (MonadQuasar m, MonadIO m)
+  => Observable r
+  -> (ObservableState r -> QuasarIO ())
+  -> m (Async a)
+observeAsync observable handler = async $ observeBlocking observable handler
+
+
+observeWith
+  :: (MonadQuasar m, MonadIO m, MonadMask m)
+  => Observable r
+  -> (STM (ObservableState r) -> m a)
+  -> m a
+observeWith observable fn = do
   var <- liftIO newEmptyTMVarIO
 
-  bracket
-    do
-      quasarAtomically $ observe observable \msg -> liftSTM do
-        void $ tryTakeTMVar var
-        putTMVar var msg
-    dispose
-    \_ -> forever do
-      msg <- liftIO $ atomically $ takeTMVar var
-      handler msg
+  bracket (aquire var) dispose
+    \_ -> fn (takeTMVar var)
+  where
+    aquire var = observeIO observable \msg -> do
+      void $ tryTakeTMVar var
+      putTMVar var msg
 
 
 -- | Internal control flow exception for `observeWhile` and `observeWhile_`.
 data ObserveWhileCompleted = ObserveWhileCompleted
   deriving stock (Eq, Show)
 
-instance Exception ObserveWhileCompleted
-
--- | Observe until the callback returns `Just`.
-observeUntil
-  :: (IsObservable r a, MonadQuasar m, MonadIO m, MonadMask m)
-  => a
-  -> (ObservableState r -> m (Maybe b))
-  -> m b
-observeUntil observable callback = do
-  resultVar <- liftIO $ newIORef unreachableCodePath
-  observeUntil_ observable \msg -> do
-    callback msg >>= \case
-      Just result -> do
-        liftIO $ writeIORef resultVar result
-        pure False
-      Nothing -> pure True
-
-  liftIO $ readIORef resultVar
-
-
--- | Observe until the callback returns `False`.
-observeUntil_
-  :: (IsObservable r a, MonadQuasar m, MonadIO m, MonadMask m)
-  => a
-  -> (ObservableState r -> m Bool)
-  -> m ()
-observeUntil_ observable callback =
-  catch
-    do
-      observeBlocking observable \msg -> do
-        continue <- callback msg
-        unless continue $ throwM ObserveWhileCompleted
-    \ObserveWhileCompleted -> pure ()
-
-
-newtype ConstObservable a = ConstObservable a
-instance IsRetrievable a (ConstObservable a) where
-  retrieve (ConstObservable x) = pure x
-instance IsObservable a (ConstObservable a) where
-  observe (ConstObservable x) callback = liftQuasarSTM do
-    callback $ ObservableValue x
-    pure trivialDisposer
-
 
-newtype ThrowObservable a = ThrowObservable SomeException
-instance IsRetrievable a (ThrowObservable a) where
-  retrieve (ThrowObservable ex) = throwM ex
-instance IsObservable a (ThrowObservable a) where
-  observe (ThrowObservable ex) callback = liftQuasarSTM do
-    callback $ ObservableNotAvailable ex
-    pure trivialDisposer
+newtype ConstObservable a = ConstObservable (ObservableState a)
+instance IsObservable a (ConstObservable a) where
+  attachObserver (ConstObservable state) callback sink = do
+    execObservableCallback callback sink state
+    pure mempty
 
 
 data MappedObservable a = forall b. MappedObservable (b -> a) (Observable b)
 instance IsObservable a (MappedObservable a) where
-  observe (MappedObservable fn observable) callback = observe observable (callback . fmap fn)
+  attachObserver (MappedObservable fn observable) callback sink = attachObserver observable (mapObservableCallback (fmap fn) callback) sink
   mapObservable f1 (MappedObservable f2 upstream) = toObservable $ MappedObservable (f1 . f2) upstream
 
 
@@ -264,82 +266,52 @@ instance IsObservable a (MappedObservable a) where
 data LiftA2Observable r = forall a b. LiftA2Observable (a -> b -> r) (Observable a) (Observable b)
 
 instance IsObservable a (LiftA2Observable a) where
-  observe (LiftA2Observable fn fx fy) callback = liftQuasarSTM do
+  attachObserver (LiftA2Observable fn fx fy) callback sink = do
     var0 <- newTVar Nothing
     var1 <- newTVar Nothing
     let callCallback = do
-          mergedValue <- liftSTM $ runMaybeT $ liftA2 (liftA2 fn) (MaybeT (readTVar var0)) (MaybeT (readTVar var1))
+          mergedValue <- runMaybeT $ liftA2 (liftA2 fn) (MaybeT (readTVar var0)) (MaybeT (readTVar var1))
           -- Run the callback only once both values have been received
-          mapM_ callback mergedValue
-    dx <- observe fx (\update -> liftSTM (writeTVar var0 (Just update)) >> callCallback)
-    dy <- observe fy (\update -> liftSTM (writeTVar var1 (Just update)) >> callCallback)
+          mapM_ (execObservableCallback callback sink) mergedValue
+    dx <- attachObserver fx (observableCallback \update -> writeTVar var0 (Just update) >> callCallback) sink
+    dy <- attachObserver fy (observableCallback \update -> writeTVar var1 (Just update) >> callCallback) sink
     pure $ dx <> dy
 
   mapObservable f1 (LiftA2Observable f2 fx fy) = toObservable $ LiftA2Observable (\x y -> f1 (f2 x y)) fx fy
 
 
-data BindObservable a = forall b. BindObservable (Observable b) (b -> Observable a)
+-- Implementation for bind and catch
+data ObservableStep a = forall b. ObservableStep (Observable b) (ObservableState b -> Observable a)
 
-instance IsObservable a (BindObservable a) where
-  observe (BindObservable fx fn) callback = liftQuasarSTM do
-    -- TODO Dispose in STM to remove potential extraneous (/invalid?) updates while disposing
-    callback ObservableLoading
-    keyVar <- newTVar =<< newUniqueSTM
-    rightDisposerVar <- newTVar trivialDisposer
-    leftDisposer <- observe fx (leftCallback keyVar rightDisposerVar)
-    registerDisposeAction do
-      dispose leftDisposer
-      -- Needs to be disposed in order since there is no way to unsubscribe atomically yet
-      dispose =<< readTVarIO rightDisposerVar
-    where
-      leftCallback keyVar rightDisposerVar lmsg = do
-        disposeEventually_ =<< readTVar rightDisposerVar
-        key <- newUniqueSTM
-        -- Dispose is not instant, so a key is used to disarm the callback derived from the last (now outdated) value
-        writeTVar keyVar key
-        disposer <-
-          case lmsg of
-            ObservableValue x -> observe (fn x) (rightCallback key)
-            ObservableLoading -> trivialDisposer <$ callback ObservableLoading
-            ObservableNotAvailable ex -> trivialDisposer <$ callback (ObservableNotAvailable ex)
-        writeTVar rightDisposerVar disposer
-        where
-          rightCallback :: Unique -> ObservableCallback a
-          rightCallback callbackKey rmsg = do
-            activeKey <- readTVar keyVar
-            when (callbackKey == activeKey) (callback rmsg)
-
-  mapObservable f (BindObservable fx fn) = toObservable $ BindObservable fx (f <<$>> fn)
-
-
-data CatchObservable e a = Exception e => CatchObservable (Observable a) (e -> Observable a)
-
-instance IsObservable a (CatchObservable e a) where
-  observe (CatchObservable fx fn) callback = liftQuasarSTM do
-    callback ObservableLoading
-    keyVar <- newTVar =<< newUniqueSTM
-    rightDisposerVar <- liftSTM $ newTVar trivialDisposer
-    leftDisposer <- observe fx (leftCallback keyVar rightDisposerVar)
-    registerDisposeAction do
-      dispose leftDisposer
-      -- Needs to be disposed in order since there is no way to unsubscribe atomically yet
-      dispose =<< readTVarIO rightDisposerVar
+instance IsObservable a (ObservableStep a) where
+  attachObserver (ObservableStep fx fn) callback ctx@(ObserverContext worker sink _) = do
+    -- Callback isn't called immediately, since subscribing to fx and fn also guarantees a callback.
+    rightDisposerVar <- newTVar mempty
+    left <- attachObserver fx (leftCallback rightDisposerVar) ctx
+    newUnmanagedSTMDisposer (disposeFn left rightDisposerVar) worker sink
     where
-      leftCallback keyVar rightDisposerVar lmsg = do
-        disposeEventually_ =<< readTVar rightDisposerVar
-        key <- newUniqueSTM
-        -- Dispose is not instant, so a key is used to disarm the callback derived from the last (now outdated) value
-        writeTVar keyVar key
-        disposer <-
-          case lmsg of
-            ObservableNotAvailable (fromException -> Just ex) -> observe (fn ex) (rightCallback key)
-            _ -> trivialDisposer <$ callback lmsg
+      leftCallback rightDisposerVar = observableCallback \lmsg -> do
+        disposeTDisposer =<< readTVar rightDisposerVar
+        disposer <- attachObserver (fn lmsg) callback ctx
         writeTVar rightDisposerVar disposer
-        where
-          rightCallback :: Unique -> ObservableCallback a
-          rightCallback callbackKey rmsg = do
-            activeKey <- readTVar keyVar
-            when (callbackKey == activeKey) (callback rmsg)
+
+      disposeFn :: TDisposer -> TVar TDisposer -> STM ()
+      disposeFn leftDisposer rightDisposerVar = do
+        rightDisposer <- swapTVar rightDisposerVar mempty
+        disposeTDisposer (leftDisposer <> rightDisposer)
+
+  mapObservable f (ObservableStep fx fn) = toObservable $ ObservableStep fx (f <<$>> fn)
+
+bindObservable :: (Observable b) -> (b -> Observable a) -> Observable a
+bindObservable fx fn = toObservable $ ObservableStep fx \case
+  ObservableValue x -> fn x
+  ObservableLoading -> toObservable (ConstObservable ObservableLoading)
+  ObservableNotAvailable ex -> throwM ex
+
+catchObservable :: Exception e => (Observable a) -> (e -> Observable a) -> Observable a
+catchObservable fx fn = toObservable $ ObservableStep fx \case
+  ObservableNotAvailable (fromException -> Just ex) -> fn ex
+  state -> toObservable (ConstObservable state)
 
 
 newtype ObserverRegistry a = ObserverRegistry (TVar (HM.HashMap Unique (ObservableState a -> STM ())))
@@ -350,13 +322,14 @@ newObserverRegistry = ObserverRegistry <$> newTVar mempty
 newObserverRegistryIO :: MonadIO m => m (ObserverRegistry a)
 newObserverRegistryIO = liftIO $ ObserverRegistry <$> newTVarIO mempty
 
-registerObserver :: ObserverRegistry a -> ObservableCallback a -> ObservableState a -> QuasarSTM Disposer
-registerObserver (ObserverRegistry var) callback currentState = do
-  quasar <- askQuasar
+registerObserver :: ObserverRegistry a -> ObserverCallback a -> ObserverContext -> ObservableState a -> STM TDisposer
+registerObserver (ObserverRegistry var) callback ctx@(ObserverContext worker sink rm) currentState = do
   key <- newUniqueSTM
-  modifyTVar var (HM.insert key (execForeignQuasarSTM quasar . callback))
-  disposer <- registerDisposeTransaction $ modifyTVar var (HM.delete key)
-  callback currentState
+  let execCallbackFn = execObservableCallback callback ctx
+  modifyTVar var (HM.insert key execCallbackFn)
+  disposer <- newUnmanagedSTMDisposer (modifyTVar var (HM.delete key)) worker sink
+  attachResource rm disposer
+  execCallbackFn currentState
   pure disposer
 
 updateObservers :: ObserverRegistry a -> ObservableState a -> STM ()
@@ -373,8 +346,8 @@ instance IsRetrievable a (ObservableVar a) where
   retrieve (ObservableVar var _registry) = liftIO $ readTVarIO var
 
 instance IsObservable a (ObservableVar a) where
-  observe (ObservableVar var registry) callback = liftQuasarSTM do
-    registerObserver registry callback . ObservableValue =<< readTVar var
+  attachObserver (ObservableVar var registry) callback sink =
+    registerObserver registry callback sink . ObservableValue =<< readTVar var
 
 newObservableVar :: MonadSTM m => a -> m (ObservableVar a)
 newObservableVar x = liftSTM $ ObservableVar <$> newTVar x <*> newObserverRegistry
@@ -414,8 +387,8 @@ instance IsRetrievable a (ObservablePrim a) where
       ObservableNotAvailable ex -> throwM ex
 
 instance IsObservable a (ObservablePrim a) where
-  observe (ObservablePrim var registry) callback = liftQuasarSTM do
-    registerObserver registry callback =<< readTVar var
+  attachObserver (ObservablePrim var registry) callback sink = do
+    registerObserver registry callback sink =<< readTVar var
 
 newObservablePrim :: MonadSTM m => ObservableState a -> m (ObservablePrim a)
 newObservablePrim x = liftSTM $ ObservablePrim <$> newTVar x <*> newObserverRegistry
-- 
GitLab