diff --git a/src/Network/Rpc.hs b/src/Network/Rpc.hs index f91ba42c6fe01934ff8cc29aa0ef7b8bb3c3854f..a2918cd563d7b395f68533355ceb5a158a2e2881 100644 --- a/src/Network/Rpc.hs +++ b/src/Network/Rpc.hs @@ -129,12 +129,14 @@ makeClient api@RpcApi{functions} = do makeClientFunction fun = do clientVarName <- newName "client" argNames <- sequence (newName . (.name) <$> fun.arguments) - makeClientFunction' clientVarName argNames + channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams) + streamNames <- sequence (newName . (.name) <$> fun.streams) + makeClientFunction' clientVarName argNames channelNames streamNames where funName :: Name funName = mkName fun.name - makeClientFunction' :: Name -> [Name] -> Q [Dec] - makeClientFunction' clientVarName argNames = do + makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Dec] + makeClientFunction' clientVarName argNames channelNames streamNames = do funArgTypes <- functionArgumentTypes fun clientType <- [t|Client $(protocolType api)|] resultType <- optionalResultType @@ -156,31 +158,46 @@ makeClient api@RpcApi{functions} = do varPats = varP <$> argNames body :: Q Body body - | hasResult fun = normalB $ [|do - result <- $(checkResult (requestE requestDataE)) - pure $(buildTuple (liftA2 (:) [|result|] streamsE)) - |] - | otherwise = normalB $ [|do - $(sendE requestDataE) - pure $(buildTuple streamsE) - |] + | hasResult fun = normalB $ doE $ + [ + bindS [p|(response, resources)|] (requestE requestDataE), + bindS [p|result|] (checkResult [|response|]) + ] <> + createStreams [|resources.createdChannels|] <> + [noBindS [|pure $(buildTuple (liftA2 (:) [|result|] streamsE))|]] + | otherwise = normalB $ doE $ + [bindS [p|resources|] (sendE requestDataE)] <> + createStreams [|resources.createdChannels|] <> + [noBindS [|pure $(buildTuple streamsE)|]] requestDataE :: Q Exp requestDataE = applyVars (conE (requestFunctionCtorName api fun)) + createStreams :: Q Exp -> [Q Stmt] + createStreams channelsE = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [] + where + assignChannels :: Q Stmt + assignChannels = letS [valD (listP (varP <$> channelNames)) (normalB channelsE) []] + go :: [Name] -> [Name] -> [Q Stmt] + go [] [] = [] + go (cn:cns) (sn:sns) = createStream cn sn : go cns sns + createStream :: Name -> Name -> Q Stmt + createStream channelName streamName = bindS (varP streamName) [|newStream $(varE channelName)|] streamsE :: Q [Exp] - streamsE = mapM (\stream -> [|undefined|]) fun.streams + streamsE = mapM varE streamNames + messageConfigurationE :: Q Exp + messageConfigurationE = [|defaultMessageConfiguration{createChannels = $(litE $ integerL $ toInteger $ length fun.streams)}|] sendE :: Q Exp -> Q Exp - sendE msgExp = [|$typedSend $(clientE) $(msgExp)|] + sendE msgExp = [|$typedSend $clientE $messageConfigurationE $msgExp|] requestE :: Q Exp -> Q Exp - requestE msgExp = [|$typedRequest $(clientE) $(msgExp)|] + requestE msgExp = [|$typedRequest $clientE $messageConfigurationE $msgExp|] applyVars :: Q Exp -> Q Exp applyVars = go argNames where go :: [Name] -> Q Exp -> Q Exp go [] ex = ex go (n:ns) ex = go ns (appE ex (varE n)) - -- check and unbox the response of a request to the result type + -- check if the response to a request matches the expected result constructor checkResult :: Q Exp -> Q Exp - checkResult x = [|$x >>= $(lamCaseE [valid, invalid])|] + checkResult x = caseE x [valid, invalid] where valid :: Q Match valid = do @@ -222,25 +239,44 @@ makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec] messageHandler = do handleMessagePrimeName <- newName "handleMessage" implName <- newName "impl" - funD 'handleMessage [clause [varP implName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName implName]] + channelsName <- newName "channels" + funD 'handleMessage [clause [varP implName, varP channelsName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName (varE implName) (varE channelsName)]] where - handleMessagePrimeDec :: Name -> Name -> Q Dec - handleMessagePrimeDec handleMessagePrimeName implName = funD handleMessagePrimeName (handlerFunctionClause <$> functions) + handleMessagePrimeDec :: Name -> Q Exp -> Q Exp -> Q Dec + handleMessagePrimeDec handleMessagePrimeName implE channelsE = funD handleMessagePrimeName (handlerFunctionClause <$> functions) where handlerFunctionClause :: RpcFunction -> Q Clause handlerFunctionClause fun = do argNames <- sequence (newName . (.name) <$> fun.arguments) - serverLogicHandlerFunctionClause' argNames + channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams) + streamNames <- sequence (newName . (.name) <$> fun.streams) + serverLogicHandlerFunctionClause' argNames channelNames streamNames where - serverLogicHandlerFunctionClause' :: [Name] -> Q Clause - serverLogicHandlerFunctionClause' argNames = clause [conP (requestFunctionCtorName api fun) varPats] body [] + serverLogicHandlerFunctionClause' :: [Name] -> [Name] -> [Name] -> Q Clause + serverLogicHandlerFunctionClause' argNames channelNames streamNames = clause [conP (requestFunctionCtorName api fun) varPats] body [] where varPats :: [Q Pat] varPats = varP <$> argNames body :: Q Body - body - | hasResult fun = normalB [|fmap Just $(packResponse (applyStreams (applyArguments implExp)))|] - | otherwise = normalB [|Nothing <$ $(applyStreams (applyArguments implExp))|] + body = normalB $ doE $ [verifyChannelCount] <> createStreams <> [callImplementation] + verifyChannelCount :: Q Stmt + verifyChannelCount = noBindS [|when (length $(channelsE) /= $(litE $ integerL $ toInteger $ length fun.streams)) (fail "Received invalid channel count")|] -- TODO channelReportProtocolError + createStreams :: [Q Stmt] + createStreams = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [] + where + assignChannels :: Q Stmt + assignChannels = letS [valD (listP (varP <$> channelNames)) (normalB channelsE) []] + go :: [Name] -> [Name] -> [Q Stmt] + go [] [] = [] + go (cn:cns) (sn:sns) = createStream cn sn : go cns sns + createStream :: Name -> Name -> Q Stmt + createStream channelName streamName = bindS (varP streamName) [|newStream $(varE channelName)|] + callImplementation :: Q Stmt + callImplementation = noBindS callImplementationE + callImplementationE :: Q Exp + callImplementationE + | hasResult fun = [|Just <$> $(packResponse (applyStreams (applyArguments implExp)))|] + | otherwise = [|Nothing <$ $(applyStreams (applyArguments implExp))|] packResponse :: Q Exp -> Q Exp packResponse = fmapE (conE (responseFunctionCtorName api fun)) applyArguments :: Q Exp -> Q Exp @@ -250,16 +286,16 @@ makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec] go [] ex = ex go (n:ns) ex = go ns (appE ex (varE n)) applyStreams :: Q Exp -> Q Exp - applyStreams = go fun.streams + applyStreams = go fun.streams streamNames where - go :: [RpcStream] -> Q Exp -> Q Exp - go [] ex = ex - go (s:ss) ex = go ss (appE ex [|undefined|]) + go :: [RpcStream] -> [Name] -> Q Exp -> Q Exp + go [] [] ex = ex + go (s:ss) (sn:sns) ex = go ss sns (appE ex (varE sn)) implExp :: Q Exp implExp = implExp' fun.fixedHandler where implExp' :: Maybe (Q Exp) -> Q Exp - implExp' Nothing = varE (implFieldName api fun) `appE` varE implName + implExp' Nothing = varE (implFieldName api fun) `appE` implE implExp' (Just handler) = [| let impl :: $(implSig) @@ -281,7 +317,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p) class RpcProtocol p => HasProtocolImpl p where type ProtocolImpl p - handleMessage :: ProtocolImpl p -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p)) + handleMessage :: ProtocolImpl p -> [Channel] -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p)) data Client p = Client { @@ -296,17 +332,18 @@ emptyClientState = ClientState { callbacks = HM.empty } -clientSend :: RpcProtocol p => Client p -> ProtocolRequest p -> IO () -clientSend client req = void $ channelSend_ client.channel [] (encode req) -clientRequestBlocking :: forall p. RpcProtocol p => Client p -> ProtocolRequest p -> IO (ProtocolResponse p) -clientRequestBlocking client req = do +clientSend :: RpcProtocol p => Client p -> MessageConfiguration -> ProtocolRequest p -> IO SentMessageResources +clientSend client config req = channelSend_ client.channel config (encode req) +clientRequestBlocking :: forall p. RpcProtocol p => Client p -> MessageConfiguration -> ProtocolRequest p -> IO (ProtocolResponse p, SentMessageResources) +clientRequestBlocking client config req = do resultMVar <- newEmptyMVar - void $ channelSend client.channel [] (encode req) $ \msgId -> + sentMessageResources <- channelSend client.channel config (encode req) $ \msgId -> modifyMVar_ client.stateMVar $ \state -> pure state{callbacks = HM.insert msgId (requestCompletedCallback resultMVar msgId) state.callbacks} -- Block on resultMVar until the request completes -- TODO: Future-based variant - takeMVar resultMVar + result <- takeMVar resultMVar + pure (result, sentMessageResources) where requestCompletedCallback :: MVar (ProtocolResponse p) -> MessageId -> ProtocolResponse p -> IO () requestCompletedCallback resultMVar msgId response = do @@ -321,7 +358,7 @@ clientHandleChannelMessage client resources msg = case decodeOrFail msg of where clientHandleResponse :: ProtocolResponseWrapper p -> IO () clientHandleResponse (requestId, resp) = do - mapM_ undefined resources.createdChannels + unless (null resources.createdChannels) (channelReportProtocolError client.channel "Received unexpected new channel during a rpc response") callback <- modifyMVar client.stateMVar $ \state -> do let (callbacks, mCallback) = lookupDelete requestId state.callbacks case mCallback of @@ -339,11 +376,11 @@ clientReportProtocolError client = channelReportProtocolError client.channel serverHandleChannelMessage :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Channel -> ReceivedMessageResources -> BSL.ByteString -> IO () serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFail msg of Left (_, _, errMsg) -> channelReportProtocolError channel errMsg - Right ("", _, req) -> serverHandleChannelRequest req + Right ("", _, req) -> serverHandleChannelRequest resources.createdChannels req Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers)) where - serverHandleChannelRequest :: ProtocolRequest p -> IO () - serverHandleChannelRequest req = handleMessage @p protocolImpl req >>= maybe (pure ()) serverSendResponse + serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO () + serverHandleChannelRequest channels req = handleMessage @p protocolImpl channels req >>= maybe (pure ()) serverSendResponse serverSendResponse :: ProtocolResponse p -> IO () serverSendResponse response = channelSendSimple channel (encode wrappedResponse) where @@ -354,16 +391,19 @@ registerChannelServerHandler :: forall p. (RpcProtocol p, HasProtocolImpl p) => registerChannelServerHandler protocolImpl channel = channelSetHandler channel (serverHandleChannelMessage @p protocolImpl channel) -data Stream up down = Stream +newtype Stream up down = Stream Channel + +newStream :: Channel -> IO (Stream up down) +newStream = pure . Stream -streamSend :: Stream up down -> up -> IO () -streamSend = undefined +streamSend :: Binary up => Stream up down -> up -> IO () +streamSend (Stream channel) value = channelSendSimple channel (encode value) -streamSetHandler :: Stream up down -> (down -> IO ()) -> IO () -streamSetHandler = undefined +streamSetHandler :: Binary down => Stream up down -> (down -> IO ()) -> IO () +streamSetHandler (Stream channel) handler = channelSetSimpleHandler channel handler streamClose :: Stream up down -> IO () -streamClose = undefined +streamClose (Stream channel) = channelClose channel -- ** Running client and server @@ -552,7 +592,7 @@ varDefaultBangType :: Name -> Q Type -> Q VarBangType varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpackedness noSourceStrictness) qType fmapE :: Q Exp -> Q Exp -> Q Exp -fmapE f = appE (appE (varE 'fmap) f) +fmapE f e = [|$(f) <$> $(e)|] -- ** General helper functions