From f0e8172ed6dc5fd0695b45a446a3e91d7470c7b0 Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Wed, 11 Aug 2021 22:29:14 +0200 Subject: [PATCH] Prepare handler signature for async calls Co-authored-by: Jan Beinke <git@janbeinke.com> --- src/Quasar/Network/Multiplexer.hs | 4 +-- src/Quasar/Network/Runtime.hs | 12 +++++-- src/Quasar/Network/Runtime/Observable.hs | 4 +-- src/Quasar/Network/TH.hs | 41 +++++++++++------------- test/Quasar/NetworkSpec.hs | 8 ++--- 5 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/Quasar/Network/Multiplexer.hs b/src/Quasar/Network/Multiplexer.hs index 82b388a..66f622a 100644 --- a/src/Quasar/Network/Multiplexer.hs +++ b/src/Quasar/Network/Multiplexer.hs @@ -502,8 +502,8 @@ reportProtocolError worker message = do multiplexerClose_ ex worker throwIO ex -channelReportProtocolError :: Channel -> String -> IO b -channelReportProtocolError channel message = do +channelReportProtocolError :: MonadIO m => Channel -> String -> m b +channelReportProtocolError channel message = liftIO $ do let channelId = channel.channelId multiplexerSend channel.worker $ ChannelProtocolError channelId message let ex = ChannelProtocolException channelId message diff --git a/src/Quasar/Network/Runtime.hs b/src/Quasar/Network/Runtime.hs index 6432392..c94ef6e 100644 --- a/src/Quasar/Network/Runtime.hs +++ b/src/Quasar/Network/Runtime.hs @@ -48,6 +48,7 @@ import qualified Data.ByteString.Lazy as BSL import qualified Data.HashMap.Strict as HM import qualified Network.Socket as Socket import Quasar.Awaitable +import Quasar.Core import Quasar.Network.Connection import Quasar.Network.Multiplexer import Quasar.Prelude @@ -64,7 +65,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p) class RpcProtocol p => HasProtocolImpl p where type ProtocolImpl p - handleRequest :: ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> IO (Maybe (ProtocolResponse p)) + handleRequest :: HasResourceManager m => ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> m (Maybe (Task (ProtocolResponse p))) data Client p = Client { @@ -129,7 +130,14 @@ serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFai Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers)) where serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO () - serverHandleChannelRequest channels req = handleRequest @p protocolImpl channel req channels >>= maybe (pure ()) serverSendResponse + serverHandleChannelRequest channels req = do + -- TODO resource manager should belong to the current channel/api + withDefaultResourceManager $ + handleRequest @p protocolImpl channel req channels >>= \case + Nothing -> pure () + Just task -> do + response <- await task + liftIO $ serverSendResponse response serverSendResponse :: ProtocolResponse p -> IO () serverSendResponse response = channelSendSimple channel (encode wrappedResponse) where diff --git a/src/Quasar/Network/Runtime/Observable.hs b/src/Quasar/Network/Runtime/Observable.hs index f46eb9d..02f9151 100644 --- a/src/Quasar/Network/Runtime/Observable.hs +++ b/src/Quasar/Network/Runtime/Observable.hs @@ -24,8 +24,8 @@ newObservableStub startObserveRequest startRetrieveRequest = pure uncachedObserv retrieveFn :: forall m. HasResourceManager m => m (Task v) retrieveFn = toTask <$> startRetrieveRequest -observeToStream :: Observable v -> Stream v Void -> IO () +observeToStream :: HasResourceManager m => Observable v -> Stream v Void -> m () observeToStream observable stream = do - disposable <- observe observable undefined + disposable <- liftIO $ observe observable undefined -- TODO: dispose when the stream is closed undefined diff --git a/src/Quasar/Network/TH.hs b/src/Quasar/Network/TH.hs index a57db0f..ca75dd1 100644 --- a/src/Quasar/Network/TH.hs +++ b/src/Quasar/Network/TH.hs @@ -246,11 +246,11 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance where protocolImplDec :: Q Dec protocolImplDec = do - dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) code.serverImplFields] [] + dataD (pure []) (implRecordTypeName api) [] Nothing [recC (implRecordTypeName api) code.serverImplFields] [] logicInstanceDec :: Q Dec logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [ - tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)), + tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implRecordType api)), requestHandler ] requestHandler :: Q Dec @@ -289,16 +289,11 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance ] handlerSig :: Name -> Q Dec - handlerSig handlerName = sigD handlerName (buildFunctionType (implResourceTypes req) [t|IO $(resultType)|]) + handlerSig handlerName = sigD handlerName (buildFunctionType (implResourceTypes req) (implResultType req)) handlerDec :: Name -> [Name] -> RequestHandlerContext -> Q Dec handlerDec handlerName resourceNames ctx = funD handlerName [clause (varP <$> resourceNames) (normalB (req.handlerE ctx)) []] applyResources :: [Q Exp] -> Q Exp -> Q Exp applyResources resourceEs implE = applyM implE resourceEs - resultType :: Q Type - resultType = - case req.mResponse of - Nothing -> [t|()|] - Just resp -> [t|$(buildTupleType (sequence ((.ty) <$> resp.fields)))|] invalidChannelCountClause :: Q Clause invalidChannelCountClause = do @@ -311,7 +306,7 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance packResponse :: Maybe Response -> Q Exp -> Q Exp packResponse Nothing handlerE = [|Nothing <$ $(handlerE)|] - packResponse (Just response) handlerE = [|Just . $(conE (responseConName api response)) <$> $handlerE|] + packResponse (Just response) handlerE = [|Just . fmap $(conE (responseConName api response)) <$> $handlerE|] -- * Pluggable codegen interface @@ -387,8 +382,7 @@ generateObservable api observable = pure Code { fields = [], createdResources = [], mResponse = Just retrieveResponse, - -- TODO use awaitable for result instead of blocking the network thread - handlerE = \ctx -> [|withDefaultResourceManager (awaitResult (retrieve $(observableE ctx)))|] + handlerE = \ctx -> [|retrieve $(observableE ctx)|] } retrieveResponse :: Response retrieveResponse = Response { @@ -439,13 +433,7 @@ generateFunction api fun = do implFieldName :: Name implFieldName = functionImplFieldName api fun implSig :: Q Type - implSig = do - argumentTypes <- functionArgumentTypes fun - streamTypes <- serverStreamTypes - buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|] - where - serverStreamTypes :: Q [Type] - serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams + implSig = buildFunctionType ((functionArgumentTypes fun) <<>> (implResourceTypes request)) (implResultType request) serverRequestHandlerE :: RequestHandlerContext -> Q Exp serverRequestHandlerE ctx = applyResources (applyArgs (implFieldE ctx.implRecordE)) ctx.resourceEs @@ -513,11 +501,11 @@ responseConName api resp = mkName (responseTypeIdentifier api <> "_" <> resp.nam clientType :: RpcApi -> Q Type clientType api = [t|Client $(protocolType api)|] -implTypeName :: RpcApi -> Name -implTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl" +implRecordTypeName :: RpcApi -> Name +implRecordTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl" -implType :: RpcApi -> Q Type -implType = conT . implTypeName +implRecordType :: RpcApi -> Q Type +implRecordType = conT . implRecordTypeName functionImplFieldName :: RpcApi -> RpcFunction -> Name functionImplFieldName _api fun = mkName (fun.name <> "Impl") @@ -564,6 +552,13 @@ createResource :: RequestCreateResource -> Q Exp -> Q Exp createResource RequestCreateChannel channelE = [|pure $channelE|] createResource (RequestCreateStream up down) channelE = [|newStream $channelE|] +implResultType :: Request -> Q Type +implResultType req = [t|forall m. HasResourceManager m => m $(resultType)|] + where + resultType = case req.mResponse of + Nothing -> [t|()|] + Just resp -> [t|Task $(buildTupleType (sequence ((.ty) <$> resp.fields)))|] + -- * Template Haskell helper functions funT :: Q Type -> Q Type -> Q Type @@ -629,7 +624,7 @@ varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpack -- * Error reporting -reportInvalidChannelCount :: Int -> [Channel] -> Channel -> IO a +reportInvalidChannelCount :: MonadIO m => Int -> [Channel] -> Channel -> m a reportInvalidChannelCount expectedCount newChannels onChannel = channelReportProtocolError onChannel msg where msg = mconcat parts diff --git a/test/Quasar/NetworkSpec.hs b/test/Quasar/NetworkSpec.hs index f8bcd58..3b8c887 100644 --- a/test/Quasar/NetworkSpec.hs +++ b/test/Quasar/NetworkSpec.hs @@ -30,12 +30,12 @@ $(makeRpc $ rpcApi "Example" $ do rpcFunction "fixedHandler42" $ do addArgument "arg" [t|Int|] addResult "result" [t|Bool|] - setFixedHandler [| pure . (== 42) |] + setFixedHandler [| pure . pure . (== 42) |] rpcFunction "fixedHandlerInc" $ do addArgument "arg" [t|Int|] addResult "result" [t|Int|] - setFixedHandler [| pure . (+ 1) |] + setFixedHandler [| pure . pure . (+ 1) |] rpcFunction "multiArgs" $ do addArgument "one" [t|Int|] @@ -68,8 +68,8 @@ $(makeRpc $ rpcApi "ObservableExample" $ do exampleProtocolImpl :: ExampleProtocolImpl exampleProtocolImpl = ExampleProtocolImpl { - multiArgsImpl = \one two three -> pure (one + two, not three), - noArgsImpl = pure 42, + multiArgsImpl = \one two three -> pure $ pure (one + two, not three), + noArgsImpl = pure $ pure 42, noResponseImpl = \_foo -> pure (), noNothingImpl = pure () } -- GitLab