From f0e8172ed6dc5fd0695b45a446a3e91d7470c7b0 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Wed, 11 Aug 2021 22:29:14 +0200
Subject: [PATCH] Prepare handler signature for async calls

Co-authored-by: Jan Beinke <git@janbeinke.com>
---
 src/Quasar/Network/Multiplexer.hs        |  4 +--
 src/Quasar/Network/Runtime.hs            | 12 +++++--
 src/Quasar/Network/Runtime/Observable.hs |  4 +--
 src/Quasar/Network/TH.hs                 | 41 +++++++++++-------------
 test/Quasar/NetworkSpec.hs               |  8 ++---
 5 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/src/Quasar/Network/Multiplexer.hs b/src/Quasar/Network/Multiplexer.hs
index 82b388a..66f622a 100644
--- a/src/Quasar/Network/Multiplexer.hs
+++ b/src/Quasar/Network/Multiplexer.hs
@@ -502,8 +502,8 @@ reportProtocolError worker message = do
   multiplexerClose_ ex worker
   throwIO ex
 
-channelReportProtocolError :: Channel -> String -> IO b
-channelReportProtocolError channel message = do
+channelReportProtocolError :: MonadIO m => Channel -> String -> m b
+channelReportProtocolError channel message = liftIO $ do
   let channelId = channel.channelId
   multiplexerSend channel.worker $ ChannelProtocolError channelId message
   let ex = ChannelProtocolException channelId message
diff --git a/src/Quasar/Network/Runtime.hs b/src/Quasar/Network/Runtime.hs
index 6432392..c94ef6e 100644
--- a/src/Quasar/Network/Runtime.hs
+++ b/src/Quasar/Network/Runtime.hs
@@ -48,6 +48,7 @@ import qualified Data.ByteString.Lazy as BSL
 import qualified Data.HashMap.Strict as HM
 import qualified Network.Socket as Socket
 import Quasar.Awaitable
+import Quasar.Core
 import Quasar.Network.Connection
 import Quasar.Network.Multiplexer
 import Quasar.Prelude
@@ -64,7 +65,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p)
 
 class RpcProtocol p => HasProtocolImpl p where
   type ProtocolImpl p
-  handleRequest :: ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> IO (Maybe (ProtocolResponse p))
+  handleRequest :: HasResourceManager m => ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> m (Maybe (Task (ProtocolResponse p)))
 
 
 data Client p = Client {
@@ -129,7 +130,14 @@ serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFai
     Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers))
   where
     serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO ()
-    serverHandleChannelRequest channels req = handleRequest @p protocolImpl channel req channels >>= maybe (pure ()) serverSendResponse
+    serverHandleChannelRequest channels req = do
+      -- TODO resource manager should belong to the current channel/api
+      withDefaultResourceManager $
+        handleRequest @p protocolImpl channel req channels >>= \case
+          Nothing -> pure ()
+          Just task -> do
+            response <- await task
+            liftIO $ serverSendResponse response
     serverSendResponse :: ProtocolResponse p -> IO ()
     serverSendResponse response = channelSendSimple channel (encode wrappedResponse)
       where
diff --git a/src/Quasar/Network/Runtime/Observable.hs b/src/Quasar/Network/Runtime/Observable.hs
index f46eb9d..02f9151 100644
--- a/src/Quasar/Network/Runtime/Observable.hs
+++ b/src/Quasar/Network/Runtime/Observable.hs
@@ -24,8 +24,8 @@ newObservableStub startObserveRequest startRetrieveRequest = pure uncachedObserv
     retrieveFn :: forall m. HasResourceManager m => m (Task v)
     retrieveFn = toTask <$> startRetrieveRequest
 
-observeToStream :: Observable v -> Stream v Void -> IO ()
+observeToStream :: HasResourceManager m => Observable v -> Stream v Void -> m ()
 observeToStream observable stream = do
-  disposable <- observe observable undefined
+  disposable <- liftIO $ observe observable undefined
   -- TODO: dispose when the stream is closed
   undefined
diff --git a/src/Quasar/Network/TH.hs b/src/Quasar/Network/TH.hs
index a57db0f..ca75dd1 100644
--- a/src/Quasar/Network/TH.hs
+++ b/src/Quasar/Network/TH.hs
@@ -246,11 +246,11 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance
   where
     protocolImplDec :: Q Dec
     protocolImplDec = do
-      dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) code.serverImplFields] []
+      dataD (pure []) (implRecordTypeName api) [] Nothing [recC (implRecordTypeName api) code.serverImplFields] []
 
     logicInstanceDec :: Q Dec
     logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [
-      tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)),
+      tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implRecordType api)),
       requestHandler
       ]
     requestHandler :: Q Dec
@@ -289,16 +289,11 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance
                 ]
 
             handlerSig :: Name -> Q Dec
-            handlerSig handlerName = sigD handlerName (buildFunctionType (implResourceTypes req) [t|IO $(resultType)|])
+            handlerSig handlerName = sigD handlerName (buildFunctionType (implResourceTypes req) (implResultType req))
             handlerDec :: Name -> [Name] -> RequestHandlerContext -> Q Dec
             handlerDec handlerName resourceNames ctx = funD handlerName [clause (varP <$> resourceNames) (normalB (req.handlerE ctx)) []]
             applyResources :: [Q Exp] -> Q Exp -> Q Exp
             applyResources resourceEs implE = applyM implE resourceEs
-            resultType :: Q Type
-            resultType =
-              case req.mResponse of
-                Nothing -> [t|()|]
-                Just resp -> [t|$(buildTupleType (sequence ((.ty) <$> resp.fields)))|]
 
             invalidChannelCountClause :: Q Clause
             invalidChannelCountClause = do
@@ -311,7 +306,7 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance
 
     packResponse :: Maybe Response -> Q Exp -> Q Exp
     packResponse Nothing handlerE = [|Nothing <$ $(handlerE)|]
-    packResponse (Just response) handlerE = [|Just . $(conE (responseConName api response)) <$> $handlerE|]
+    packResponse (Just response) handlerE = [|Just . fmap $(conE (responseConName api response)) <$> $handlerE|]
 
 
 -- * Pluggable codegen interface
@@ -387,8 +382,7 @@ generateObservable api observable = pure Code {
       fields = [],
       createdResources = [],
       mResponse = Just retrieveResponse,
-      -- TODO use awaitable for result instead of blocking the network thread
-      handlerE = \ctx -> [|withDefaultResourceManager (awaitResult (retrieve $(observableE ctx)))|]
+      handlerE = \ctx -> [|retrieve $(observableE ctx)|]
       }
     retrieveResponse :: Response
     retrieveResponse = Response {
@@ -439,13 +433,7 @@ generateFunction api fun = do
     implFieldName :: Name
     implFieldName = functionImplFieldName api fun
     implSig :: Q Type
-    implSig = do
-      argumentTypes <- functionArgumentTypes fun
-      streamTypes <- serverStreamTypes
-      buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
-      where
-        serverStreamTypes :: Q [Type]
-        serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
+    implSig = buildFunctionType ((functionArgumentTypes fun) <<>> (implResourceTypes request)) (implResultType request)
 
     serverRequestHandlerE :: RequestHandlerContext -> Q Exp
     serverRequestHandlerE ctx = applyResources (applyArgs (implFieldE ctx.implRecordE)) ctx.resourceEs
@@ -513,11 +501,11 @@ responseConName api resp = mkName (responseTypeIdentifier api <> "_" <> resp.nam
 clientType :: RpcApi -> Q Type
 clientType api = [t|Client $(protocolType api)|]
 
-implTypeName :: RpcApi -> Name
-implTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl"
+implRecordTypeName :: RpcApi -> Name
+implRecordTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl"
 
-implType :: RpcApi -> Q Type
-implType = conT . implTypeName
+implRecordType :: RpcApi -> Q Type
+implRecordType = conT . implRecordTypeName
 
 functionImplFieldName :: RpcApi -> RpcFunction -> Name
 functionImplFieldName _api fun = mkName (fun.name <> "Impl")
@@ -564,6 +552,13 @@ createResource :: RequestCreateResource -> Q Exp -> Q Exp
 createResource RequestCreateChannel channelE = [|pure $channelE|]
 createResource (RequestCreateStream up down) channelE = [|newStream $channelE|]
 
+implResultType :: Request -> Q Type
+implResultType req = [t|forall m. HasResourceManager m => m $(resultType)|]
+  where
+    resultType = case req.mResponse of
+      Nothing -> [t|()|]
+      Just resp -> [t|Task $(buildTupleType (sequence ((.ty) <$> resp.fields)))|]
+
 -- * Template Haskell helper functions
 
 funT :: Q Type -> Q Type -> Q Type
@@ -629,7 +624,7 @@ varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpack
 
 -- * Error reporting
 
-reportInvalidChannelCount :: Int -> [Channel] -> Channel -> IO a
+reportInvalidChannelCount :: MonadIO m => Int -> [Channel] -> Channel -> m a
 reportInvalidChannelCount expectedCount newChannels onChannel = channelReportProtocolError onChannel msg
   where
     msg = mconcat parts
diff --git a/test/Quasar/NetworkSpec.hs b/test/Quasar/NetworkSpec.hs
index f8bcd58..3b8c887 100644
--- a/test/Quasar/NetworkSpec.hs
+++ b/test/Quasar/NetworkSpec.hs
@@ -30,12 +30,12 @@ $(makeRpc $ rpcApi "Example" $ do
     rpcFunction "fixedHandler42" $ do
       addArgument "arg" [t|Int|]
       addResult "result" [t|Bool|]
-      setFixedHandler [| pure . (== 42) |]
+      setFixedHandler [| pure . pure . (== 42) |]
 
     rpcFunction "fixedHandlerInc" $ do
       addArgument "arg" [t|Int|]
       addResult "result" [t|Int|]
-      setFixedHandler [| pure . (+ 1) |]
+      setFixedHandler [| pure . pure . (+ 1) |]
 
     rpcFunction "multiArgs" $ do
       addArgument "one" [t|Int|]
@@ -68,8 +68,8 @@ $(makeRpc $ rpcApi "ObservableExample" $ do
 
 exampleProtocolImpl :: ExampleProtocolImpl
 exampleProtocolImpl = ExampleProtocolImpl {
-  multiArgsImpl = \one two three -> pure (one + two, not three),
-  noArgsImpl = pure 42,
+  multiArgsImpl = \one two three -> pure $ pure (one + two, not three),
+  noArgsImpl = pure $ pure 42,
   noResponseImpl = \_foo -> pure (),
   noNothingImpl = pure ()
 }
-- 
GitLab