From 9859bde59f7cbb82ca35137f2920cdba624600be Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Wed, 21 Jul 2021 03:36:28 +0200
Subject: [PATCH] Add handler for callback exceptions

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 quasar.cabal             |  1 +
 src/Quasar/Core.hs       | 86 ++++++++++++++++++++++------------------
 src/Quasar/Observable.hs |  2 +-
 test/Quasar/AsyncSpec.hs | 13 +-----
 4 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/quasar.cabal b/quasar.cabal
index 5214f83..bf523f3 100644
--- a/quasar.cabal
+++ b/quasar.cabal
@@ -64,6 +64,7 @@ library
   build-depends:
     base >=4.7 && <5,
     binary,
+    exceptions,
     ghc-prim,
     hashable,
     microlens-platform,
diff --git a/src/Quasar/Core.hs b/src/Quasar/Core.hs
index 77d4272..f4764d1 100644
--- a/src/Quasar/Core.hs
+++ b/src/Quasar/Core.hs
@@ -11,7 +11,6 @@ module Quasar.Core (
   async,
   await,
   runAsyncIO,
-  startAsyncIO,
 
   -- * AsyncVar
   AsyncVar,
@@ -28,7 +27,7 @@ module Quasar.Core (
   noDisposable,
 ) where
 
-import Control.Exception (try)
+import Control.Monad.Catch
 import Data.HashMap.Strict qualified as HM
 import Quasar.Prelude
 
@@ -39,7 +38,7 @@ class IsAsync r a | a -> r where
   wait :: a -> IO r
   wait x = do
     mvar <- newEmptyMVar
-    onResult_ x (resultCallback mvar)
+    onResult_ (void . tryPutMVar mvar . Left) x (resultCallback mvar)
     readMVar mvar >>= either throwIO pure
     where
       resultCallback :: MVar (Either SomeException r) -> Either SomeException r -> IO ()
@@ -54,14 +53,20 @@ class IsAsync r a | a -> r where
   --
   -- The returned `Disposable` can be used to deregister the callback.
   onResult
-    :: a
+    :: (SomeException -> IO ())
+    -- ^ callback exception handler
+    -> a
     -- ^ async
     -> (Either SomeException r -> IO ())
     -- ^ callback
     -> IO Disposable
 
-  onResult_ :: a -> (Either SomeException r -> IO ()) -> IO ()
-  onResult_ x = void . onResult x
+  onResult_
+    :: (SomeException -> IO ())
+    -> a
+    -> (Either SomeException r -> IO ())
+    -> IO ()
+  onResult_ x y = void . onResult x y
 
   toAsync :: a -> Async r
   toAsync = SomeAsync
@@ -72,8 +77,8 @@ data Async r = forall a. IsAsync r a => SomeAsync a
 
 instance IsAsync r (Async r) where
   wait (SomeAsync x) = wait x
-  onResult (SomeAsync x) = onResult x
-  onResult_ (SomeAsync x) = onResult_ x
+  onResult y (SomeAsync x) = onResult y x
+  onResult_ y (SomeAsync x) = onResult_ y x
   peekAsync (SomeAsync x) = peekAsync x
   toAsync = id
 
@@ -81,7 +86,7 @@ instance IsAsync r (Async r) where
 newtype CompletedAsync r = CompletedAsync (Either SomeException r)
 instance IsAsync r (CompletedAsync r) where
   wait (CompletedAsync value) = either throwIO pure value
-  onResult (CompletedAsync value) callback = noDisposable <$ callback value
+  onResult callbackExceptionHandler (CompletedAsync value) callback = noDisposable <$ (callback value `catch` callbackExceptionHandler)
   peekAsync (CompletedAsync value) = pure $ Just value
 
 completedAsync :: Either SomeException r -> Async r
@@ -106,35 +111,36 @@ instance Applicative AsyncIO where
   (<*>) pf px = pf >>= \f -> f <$> px
   liftA2 f px py = px >>= \x -> f x <$> py
 instance Monad AsyncIO where
-  lhs >>= fn = AsyncIO $ do
-    resultVar <- newAsyncVar
-    lhsAsync <- startAsyncIO lhs
-    lhsAsync `onResult_` \case
-      Right lhsResult -> do
-        rhsAsync <- startAsyncIO $ fn lhsResult
-        rhsAsync `onResult_` putAsyncVarEither resultVar
-      Left lhsEx -> putAsyncVarEither resultVar (Left lhsEx)
-    pure $ toAsync resultVar
+  (>>=) :: forall a b. AsyncIO a -> (a -> AsyncIO b) -> AsyncIO b
+  lhs >>= fn = AsyncIO $ newAsyncVar >>= go
+    where
+      go resultVar = do
+        lhsAsync <- async lhs
+        lhsAsync `onResultBound` \case
+          Right lhsResult ->  do
+            rhsAsync <- async $ fn lhsResult
+            rhsAsync `onResultBound` putAsyncVarEither resultVar
+          Left lhsEx -> putAsyncVarEither resultVar (Left lhsEx)
+        pure $ toAsync resultVar
+        where
+          onResultBound :: forall r. Async r -> (Either SomeException r -> IO ()) -> IO ()
+          onResultBound = onResult_ (putAsyncVarEither resultVar . Left)
 
 instance MonadIO AsyncIO where
-  liftIO = AsyncIO . fmap completedAsync . try
+  liftIO = AsyncIO . fmap successfulAsync
 
 
 -- | Run the synchronous part of an `AsyncIO` and then return an `Async` that can be used to wait for completion of the synchronous part.
-async :: AsyncIO r -> AsyncIO (Async r)
-async = liftIO . startAsyncIO
+async :: MonadIO m => AsyncIO r -> m (Async r)
+async (AsyncIO x) = liftIO x
 
 await :: IsAsync r a => a -> AsyncIO r
 await = AsyncIO . pure . toAsync
 
 -- | Run an `AsyncIO` to completion and return the result.
 runAsyncIO :: AsyncIO r -> IO r
-runAsyncIO = startAsyncIO >=> wait
-
+runAsyncIO = async >=> wait
 
--- | Run the synchronous part of an `AsyncIO`. Returns an `Async` that can be used to wait for completion of the operation.
-startAsyncIO :: AsyncIO r -> IO (Async r)
-startAsyncIO (AsyncIO x) = x
 
 -- ** Forking asyncs
 
@@ -149,7 +155,7 @@ startAsyncIO (AsyncIO x) = x
 
 -- | The default implementation for a `Async` that can be fulfilled later.
 newtype AsyncVar r = AsyncVar (MVar (AsyncVarState r))
-data AsyncVarState r = AsyncVarCompleted (Either SomeException r) | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO ()))
+data AsyncVarState r = AsyncVarCompleted (Either SomeException r) | AsyncVarOpen (HM.HashMap Unique (Either SomeException r -> IO (), SomeException -> IO ()))
 
 instance IsAsync r (AsyncVar r) where
   peekAsync :: AsyncVar r -> IO (Maybe (Either SomeException r))
@@ -157,12 +163,12 @@ instance IsAsync r (AsyncVar r) where
     AsyncVarCompleted x -> Just x
     AsyncVarOpen _ -> Nothing
 
-  onResult :: AsyncVar r -> (Either SomeException r -> IO ()) -> IO Disposable
-  onResult (AsyncVar mvar) callback =
+  onResult :: (SomeException -> IO ()) -> AsyncVar r -> (Either SomeException r -> IO ()) -> IO Disposable
+  onResult callbackExceptionHandler (AsyncVar mvar) callback =
     modifyMVar mvar $ \case
       AsyncVarOpen callbacks -> do
         key <- newUnique
-        pure (AsyncVarOpen (HM.insert key callback callbacks), removeHandler key)
+        pure (AsyncVarOpen (HM.insert key (callback, callbackExceptionHandler) callbacks), removeHandler key)
       x@(AsyncVarCompleted value) -> (x, noDisposable) <$ callback value
     where
       removeHandler :: Unique -> Disposable
@@ -179,13 +185,17 @@ putAsyncVar asyncVar = putAsyncVarEither asyncVar . Right
 
 putAsyncVarEither :: MonadIO m => AsyncVar a -> Either SomeException a -> m ()
 putAsyncVarEither (AsyncVar mvar) value = liftIO $ do
-  modifyMVar_ mvar $ \case
-    AsyncVarCompleted _ -> fail "An AsyncVar can only be fulfilled once"
-    AsyncVarOpen callbacksMap -> do
-      let callbacks = HM.elems callbacksMap
-      -- NOTE disposing a callback while it is called is a deadlock
-      mapM_ ($ value) callbacks
-      pure (AsyncVarCompleted value)
+  mask $ \restore -> do
+    takeMVar mvar >>= \case
+      x@(AsyncVarCompleted _) -> do
+        putMVar mvar x
+        fail "An AsyncVar can only be fulfilled once"
+      AsyncVarOpen callbacksMap -> do
+        let callbacks = HM.elems callbacksMap
+        -- NOTE disposing a callback while it is called is a deadlock
+        forM_ callbacks $ \(callback, callbackExceptionHandler) ->
+          restore (callback value) `catch` callbackExceptionHandler
+        putMVar mvar (AsyncVarCompleted value)
 
 
 -- * Disposable
@@ -205,7 +215,7 @@ disposeIO = runAsyncIO . dispose
 
 -- | Dispose a resource. Returns without waiting for the resource to be released if possible.
 disposeEventually :: IsDisposable a => a -> IO ()
-disposeEventually = void . startAsyncIO . dispose
+disposeEventually = void . async . dispose
 
 instance IsDisposable a => IsDisposable (Maybe a) where
   dispose = mapM_ dispose
diff --git a/src/Quasar/Observable.hs b/src/Quasar/Observable.hs
index 01ff67b..9f7cd42 100644
--- a/src/Quasar/Observable.hs
+++ b/src/Quasar/Observable.hs
@@ -187,7 +187,7 @@ instance forall o i v. (IsObservable i o, IsObservable v i) => IsObservable v (J
         outerCallback :: MVar Disposable -> ObservableMessage i -> IO ()
         outerCallback innerDisposableMVar (_reason, innerObservable) = do
           oldInnerSubscription <- takeMVar innerDisposableMVar
-          void $ startAsyncIO $ do
+          void $ async $ do
             dispose oldInnerSubscription
             liftIO $ do
               newInnerSubscription <- subscribe innerObservable callback
diff --git a/test/Quasar/AsyncSpec.hs b/test/Quasar/AsyncSpec.hs
index 06dd0b6..8c5a1f7 100644
--- a/test/Quasar/AsyncSpec.hs
+++ b/test/Quasar/AsyncSpec.hs
@@ -1,7 +1,7 @@
 module Quasar.AsyncSpec (spec) where
 
-import Control.Applicative (liftA2)
 import Control.Concurrent
+import Control.Exception (throwIO)
 import Control.Monad.IO.Class
 import Data.Either (isRight)
 import Prelude
@@ -26,7 +26,7 @@ spec = parallel $ do
       avar <- newAsyncVar :: IO (AsyncVar ())
 
       mvar <- newEmptyMVar
-      avar `onResult_` putMVar mvar
+      onResult_ throwIO avar (putMVar mvar)
 
       (() <$) <$> tryTakeMVar mvar `shouldReturn` Nothing
 
@@ -46,12 +46,3 @@ spec = parallel $ do
 
     it "can continue after awaiting an already finished operation" $ do
       runAsyncIO (await =<< async (pure 42 :: AsyncIO Int)) `shouldReturn` 42
-
-    --it "can continue after blocking on an async that is completed from another thread" $ do
-    --  a1 <- newAsyncVar
-    --  a2 <- newAsyncVar
-    --  a3 <- newAsyncVar
-    --  a4 <- newAsyncVar
-    --  _ <- forkIO $ runAsyncIO $ await a1 >>= putAsyncVar a2 >> await a3 >>= putAsyncVar a4
-    --  runAsyncIO ((await a2 >> (await a4 *> putAsyncVar a3 1)) *> putAsyncVar a1 41)
-    --  liftA2 (+) (wait a2) (wait a4) `shouldReturn` (42 :: Int)
-- 
GitLab