From 2d71739d90a4435a397070aace4afcdec2f305e0 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 29 Mar 2022 22:23:15 +0200
Subject: [PATCH] Fix observable unsubscribe releasing unrelated resources

---
 src/Quasar/Observable.hs | 86 ++++++++++++++++++++--------------------
 1 file changed, 44 insertions(+), 42 deletions(-)

diff --git a/src/Quasar/Observable.hs b/src/Quasar/Observable.hs
index 9b70c24..c52b0f9 100644
--- a/src/Quasar/Observable.hs
+++ b/src/Quasar/Observable.hs
@@ -2,6 +2,7 @@ module Quasar.Observable (
   -- * Observable core types
   IsRetrievable(..),
   IsObservable(..),
+  observe_,
   Observable(..),
   ObservableState(..),
   --toObservableUpdate,
@@ -62,14 +63,6 @@ instance Monad ObservableState where
   (ObservableNotAvailable ex) >>= _ = ObservableNotAvailable ex
 
 
--- TODO rename or delete
---toObservableUpdate :: MonadThrow m => ObservableState a -> m (Maybe a)
---toObservableUpdate (ObservableValue value) = pure $ Just value
---toObservableUpdate ObservableLoading = pure Nothing
---toObservableUpdate (ObservableNotAvailable ex) = throwM ex
-
-
-
 class IsRetrievable r a | a -> r where
   retrieve :: (MonadQuasar m, MonadIO m) => a -> m r
 
@@ -87,7 +80,7 @@ class IsRetrievable r a => IsObservable r a | a -> r where
     :: (MonadQuasar m)
     => a -- ^ observable
     -> ObservableCallback r -- ^ callback
-    -> m ()
+    -> m [Disposer]
   observe observable = observe (toObservable observable)
 
   pingObservable
@@ -105,6 +98,14 @@ class IsRetrievable r a => IsObservable r a | a -> r where
   {-# MINIMAL toObservable | observe, pingObservable #-}
 
 
+observe_
+    :: (IsObservable r a, MonadQuasar m)
+    => a -- ^ observable
+    -> ObservableCallback r -- ^ callback
+    -> m ()
+observe_ observable callback = void $ observe observable callback
+
+
 type ObservableCallback v = ObservableState v -> QuasarSTM ()
 
 
@@ -158,16 +159,15 @@ observeBlocking
   -> (ObservableState r -> m ())
   -> m b
 observeBlocking observable handler = do
-  -- `withResourceScope` removes the `observe` callback when the `handler` fails.
-  -- TODO this also releases all resources when the handler fails - is that correct? if so it should be documented
-  withResourceScope do
-    var <- liftIO newEmptyTMVarIO
+  var <- liftIO newEmptyTMVarIO
 
-    observe observable \msg -> liftSTM do
-      void $ tryTakeTMVar var
-      putTMVar var msg
-
-    forever do
+  bracket
+    do
+      observe observable \msg -> liftSTM do
+        void $ tryTakeTMVar var
+        putTMVar var msg
+    dispose
+    \_ -> forever do
       msg <- liftIO $ atomically $ takeTMVar var
       handler msg
 
@@ -215,8 +215,9 @@ newtype ConstObservable a = ConstObservable a
 instance IsRetrievable a (ConstObservable a) where
   retrieve (ConstObservable x) = pure x
 instance IsObservable a (ConstObservable a) where
-  observe (ConstObservable x) callback = ensureQuasarSTM $
+  observe (ConstObservable x) callback = ensureQuasarSTM do
     callback $ ObservableValue x
+    pure []
   pingObservable _ = pure ()
 
 
@@ -224,8 +225,9 @@ newtype ThrowObservable a = ThrowObservable SomeException
 instance IsRetrievable a (ThrowObservable a) where
   retrieve (ThrowObservable ex) = throwM ex
 instance IsObservable a (ThrowObservable a) where
-  observe (ThrowObservable ex) callback = ensureQuasarSTM $
+  observe (ThrowObservable ex) callback = ensureQuasarSTM do
     callback $ ObservableNotAvailable ex
+    pure []
   pingObservable _ = pure ()
 
 
@@ -258,8 +260,9 @@ instance IsObservable a (LiftA2Observable a) where
           mergedValue <- liftSTM $ runMaybeT $ liftA2 (liftA2 fn) (MaybeT (readTVar var0)) (MaybeT (readTVar var1))
           -- Run the callback only once both values have been received
           mapM_ callback mergedValue
-    observe fx (\update -> liftSTM (writeTVar var0 (Just update)) >> callCallback)
-    observe fy (\update -> liftSTM (writeTVar var1 (Just update)) >> callCallback)
+    dx <- observe fx (\update -> liftSTM (writeTVar var0 (Just update)) >> callCallback)
+    dy <- observe fy (\update -> liftSTM (writeTVar var1 (Just update)) >> callCallback)
+    pure $ dx <> dy
 
   pingObservable (LiftA2Observable _ fx fy) = liftQuasarIO do
     -- LATER: keep backpressure for parallel network requests
@@ -281,7 +284,7 @@ instance IsObservable a (BindObservable a) where
   observe (BindObservable fx fn) callback = ensureQuasarSTM do
     callback ObservableLoading
     keyVar <- newTVar =<< newUniqueSTM
-    disposableVar <- liftSTM $ newTVar trivialDisposer
+    disposableVar <- liftSTM $ newTVar []
     observe fx (leftCallback keyVar disposableVar)
     where
       leftCallback keyVar disposableVar lmsg = do
@@ -289,11 +292,11 @@ instance IsObservable a (BindObservable a) where
         key <- newUniqueSTM
         -- Dispose is not instant, so a key is used to disarm the callback derived from the last (now outdated) value
         writeTVar keyVar key
-        disposer <- captureResources_
+        disposer <-
           case lmsg of
             ObservableValue x -> observe (fn x) (rightCallback key)
-            ObservableLoading -> callback ObservableLoading
-            ObservableNotAvailable ex -> callback (ObservableNotAvailable ex)
+            ObservableLoading -> [] <$ callback ObservableLoading
+            ObservableNotAvailable ex -> [] <$ callback (ObservableNotAvailable ex)
         writeTVar disposableVar disposer
         where
           rightCallback :: Unique -> ObservableCallback a
@@ -317,7 +320,7 @@ instance IsObservable a (CatchObservable e a) where
   observe (CatchObservable fx fn) callback = ensureQuasarSTM do
     callback ObservableLoading
     keyVar <- newTVar =<< newUniqueSTM
-    disposableVar <- liftSTM $ newTVar trivialDisposer
+    disposableVar <- liftSTM $ newTVar []
     observe fx (leftCallback keyVar disposableVar)
     where
       leftCallback keyVar disposableVar lmsg = do
@@ -325,10 +328,10 @@ instance IsObservable a (CatchObservable e a) where
         key <- newUniqueSTM
         -- Dispose is not instant, so a key is used to disarm the callback derived from the last (now outdated) value
         writeTVar keyVar key
-        disposer <- captureResources_
+        disposer <-
           case lmsg of
             ObservableNotAvailable (fromException -> Just ex) -> observe (fn ex) (rightCallback key)
-            _ -> callback lmsg
+            _ -> [] <$ callback lmsg
         writeTVar disposableVar disposer
         where
           rightCallback :: Unique -> ObservableCallback a
@@ -340,7 +343,7 @@ instance IsObservable a (CatchObservable e a) where
     pingObservable fx `catch` \ex -> pingObservable (fn ex)
 
 
-newtype ObserverRegistry a = ObserverRegistry (TVar (HM.HashMap Unique (ObservableCallback a)))
+newtype ObserverRegistry a = ObserverRegistry (TVar (HM.HashMap Unique (ObservableState a -> STM ())))
 
 newObserverRegistry :: STM (ObserverRegistry a)
 newObserverRegistry = ObserverRegistry <$> newTVar mempty
@@ -348,17 +351,18 @@ newObserverRegistry = ObserverRegistry <$> newTVar mempty
 newObserverRegistryIO :: MonadIO m => m (ObserverRegistry a)
 newObserverRegistryIO = liftIO $ ObserverRegistry <$> newTVarIO mempty
 
-registerObserver :: ObserverRegistry a -> ObservableCallback a -> ObservableState a -> QuasarSTM ()
+registerObserver :: ObserverRegistry a -> ObservableCallback a -> ObservableState a -> QuasarSTM [Disposer]
 registerObserver (ObserverRegistry var) callback currentState = do
   quasar <- askQuasar
   key <- ensureSTM newUniqueSTM
   ensureSTM $ modifyTVar var (HM.insert key (execForeignQuasarSTM quasar . callback))
-  registerDisposeTransaction_ $ modifyTVar var (HM.delete key)
+  disposer <- registerDisposeTransaction $ modifyTVar var (HM.delete key)
   callback currentState
+  pure [disposer]
 
-updateObservers :: ObserverRegistry a -> ObservableState a -> QuasarSTM ()
+updateObservers :: ObserverRegistry a -> ObservableState a -> STM ()
 updateObservers (ObserverRegistry var) newState =
-  mapM_ ($ newState) . HM.elems =<< ensureSTM (readTVar var)
+  mapM_ ($ newState) . HM.elems =<< readTVar var
 
 
 data ObservableVar a = ObservableVar (TVar a) (ObserverRegistry a)
@@ -378,21 +382,19 @@ newObservableVar x = liftSTM $ ObservableVar <$> newTVar x <*> newObserverRegist
 newObservableVarIO :: MonadIO m => a -> m (ObservableVar a)
 newObservableVarIO x = liftIO $ ObservableVar <$> newTVarIO x <*> newObserverRegistryIO
 
-setObservableVar :: MonadQuasar m => ObservableVar a -> a -> m ()
+setObservableVar :: MonadSTM m => ObservableVar a -> a -> m ()
 setObservableVar var = modifyObservableVar var . const
 
-modifyObservableVar :: MonadQuasar m => ObservableVar a -> (a -> a) -> m ()
+modifyObservableVar :: MonadSTM m => ObservableVar a -> (a -> a) -> m ()
 modifyObservableVar var f = stateObservableVar var (((), ) . f)
 
-stateObservableVar :: MonadQuasar m => ObservableVar a -> (a -> (r, a)) -> m r
-stateObservableVar (ObservableVar var registry) f = ensureQuasarSTM do
-  (result, newValue) <- liftSTM do
+stateObservableVar :: MonadSTM m => ObservableVar a -> (a -> (r, a)) -> m r
+stateObservableVar (ObservableVar var registry) f = liftSTM do
     oldValue <- readTVar var
     let (result, newValue) = f oldValue
     writeTVar var newValue
-    pure (result, newValue)
-  updateObservers registry $ ObservableValue newValue
-  pure result
+    updateObservers registry $ ObservableValue newValue
+    pure result
 
 
 ---- TODO implement
-- 
GitLab