From f52bbd9128687cbcffd8fb2f13c319e4859477ac Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Mon, 21 Jun 2021 17:32:24 +0200
Subject: [PATCH] Implement streams

---
 src/Network/Rpc.hs | 136 +++++++++++++++++++++++++++++----------------
 1 file changed, 88 insertions(+), 48 deletions(-)

diff --git a/src/Network/Rpc.hs b/src/Network/Rpc.hs
index f91ba42..a2918cd 100644
--- a/src/Network/Rpc.hs
+++ b/src/Network/Rpc.hs
@@ -129,12 +129,14 @@ makeClient api@RpcApi{functions} = do
     makeClientFunction fun = do
       clientVarName <- newName "client"
       argNames <- sequence (newName . (.name) <$> fun.arguments)
-      makeClientFunction' clientVarName argNames
+      channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
+      streamNames <- sequence (newName . (.name) <$> fun.streams)
+      makeClientFunction' clientVarName argNames channelNames streamNames
       where
         funName :: Name
         funName = mkName fun.name
-        makeClientFunction' :: Name -> [Name] -> Q [Dec]
-        makeClientFunction' clientVarName argNames = do
+        makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Dec]
+        makeClientFunction' clientVarName argNames channelNames streamNames = do
           funArgTypes <- functionArgumentTypes fun
           clientType <- [t|Client $(protocolType api)|]
           resultType <- optionalResultType
@@ -156,31 +158,46 @@ makeClient api@RpcApi{functions} = do
             varPats = varP <$> argNames
             body :: Q Body
             body
-              | hasResult fun = normalB $ [|do
-                  result <- $(checkResult (requestE requestDataE))
-                  pure $(buildTuple (liftA2 (:) [|result|] streamsE))
-                  |]
-              | otherwise = normalB $ [|do
-                  $(sendE requestDataE)
-                  pure $(buildTuple streamsE)
-                  |]
+              | hasResult fun = normalB $ doE $
+                  [
+                    bindS [p|(response, resources)|] (requestE requestDataE),
+                    bindS [p|result|] (checkResult [|response|])
+                  ] <>
+                  createStreams [|resources.createdChannels|] <>
+                  [noBindS [|pure $(buildTuple (liftA2 (:) [|result|] streamsE))|]]
+              | otherwise = normalB $ doE $
+                  [bindS [p|resources|] (sendE requestDataE)] <>
+                  createStreams [|resources.createdChannels|] <>
+                  [noBindS [|pure $(buildTuple streamsE)|]]
             requestDataE :: Q Exp
             requestDataE = applyVars (conE (requestFunctionCtorName api fun))
+            createStreams :: Q Exp -> [Q Stmt]
+            createStreams channelsE = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else []
+              where
+                assignChannels :: Q Stmt
+                assignChannels = letS [valD (listP (varP <$> channelNames)) (normalB channelsE) []]
+                go :: [Name] -> [Name] -> [Q Stmt]
+                go [] [] = []
+                go (cn:cns) (sn:sns) = createStream cn sn : go cns sns
+                createStream :: Name -> Name -> Q Stmt
+                createStream channelName streamName = bindS (varP streamName) [|newStream $(varE channelName)|]
             streamsE :: Q [Exp]
-            streamsE = mapM (\stream -> [|undefined|]) fun.streams
+            streamsE = mapM varE streamNames
+            messageConfigurationE :: Q Exp
+            messageConfigurationE = [|defaultMessageConfiguration{createChannels = $(litE $ integerL $ toInteger $ length fun.streams)}|]
             sendE :: Q Exp -> Q Exp
-            sendE msgExp = [|$typedSend $(clientE) $(msgExp)|]
+            sendE msgExp = [|$typedSend $clientE $messageConfigurationE $msgExp|]
             requestE :: Q Exp -> Q Exp
-            requestE msgExp = [|$typedRequest $(clientE) $(msgExp)|]
+            requestE msgExp = [|$typedRequest $clientE $messageConfigurationE $msgExp|]
             applyVars :: Q Exp -> Q Exp
             applyVars = go argNames
               where
                 go :: [Name] -> Q Exp -> Q Exp
                 go [] ex = ex
                 go (n:ns) ex = go ns (appE ex (varE n))
-            -- check and unbox the response of a request to the result type
+            -- check if the response to a request matches the expected result constructor
             checkResult :: Q Exp -> Q Exp
-            checkResult x = [|$x >>= $(lamCaseE [valid, invalid])|]
+            checkResult x = caseE x [valid, invalid]
               where
                 valid :: Q Match
                 valid = do
@@ -222,25 +239,44 @@ makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec]
     messageHandler = do
       handleMessagePrimeName <- newName "handleMessage"
       implName <- newName "impl"
-      funD 'handleMessage [clause [varP implName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName implName]]
+      channelsName <- newName "channels"
+      funD 'handleMessage [clause [varP implName, varP channelsName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName (varE implName) (varE channelsName)]]
       where
-        handleMessagePrimeDec :: Name -> Name -> Q Dec
-        handleMessagePrimeDec handleMessagePrimeName implName = funD handleMessagePrimeName (handlerFunctionClause <$> functions)
+        handleMessagePrimeDec :: Name -> Q Exp -> Q Exp -> Q Dec
+        handleMessagePrimeDec handleMessagePrimeName implE channelsE = funD handleMessagePrimeName (handlerFunctionClause <$> functions)
           where
             handlerFunctionClause :: RpcFunction -> Q Clause
             handlerFunctionClause fun = do
               argNames <- sequence (newName . (.name) <$> fun.arguments)
-              serverLogicHandlerFunctionClause' argNames
+              channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
+              streamNames <- sequence (newName . (.name) <$> fun.streams)
+              serverLogicHandlerFunctionClause' argNames channelNames streamNames
               where
-                serverLogicHandlerFunctionClause' :: [Name] -> Q Clause
-                serverLogicHandlerFunctionClause' argNames = clause [conP (requestFunctionCtorName api fun) varPats] body []
+                serverLogicHandlerFunctionClause' :: [Name] -> [Name] -> [Name] -> Q Clause
+                serverLogicHandlerFunctionClause' argNames channelNames streamNames = clause [conP (requestFunctionCtorName api fun) varPats] body []
                   where
                     varPats :: [Q Pat]
                     varPats = varP <$> argNames
                     body :: Q Body
-                    body
-                      | hasResult fun = normalB [|fmap Just $(packResponse (applyStreams (applyArguments implExp)))|]
-                      | otherwise = normalB [|Nothing <$ $(applyStreams (applyArguments implExp))|]
+                    body = normalB $ doE $ [verifyChannelCount] <> createStreams <> [callImplementation]
+                    verifyChannelCount :: Q Stmt
+                    verifyChannelCount = noBindS [|when (length $(channelsE) /= $(litE $ integerL $ toInteger $ length fun.streams)) (fail "Received invalid channel count")|] -- TODO channelReportProtocolError
+                    createStreams :: [Q Stmt]
+                    createStreams = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else []
+                      where
+                        assignChannels :: Q Stmt
+                        assignChannels = letS [valD (listP (varP <$> channelNames)) (normalB channelsE) []]
+                        go :: [Name] -> [Name] -> [Q Stmt]
+                        go [] [] = []
+                        go (cn:cns) (sn:sns) = createStream cn sn : go cns sns
+                        createStream :: Name -> Name -> Q Stmt
+                        createStream channelName streamName = bindS (varP streamName) [|newStream $(varE channelName)|]
+                    callImplementation :: Q Stmt
+                    callImplementation = noBindS callImplementationE
+                    callImplementationE :: Q Exp
+                    callImplementationE
+                      | hasResult fun = [|Just <$> $(packResponse (applyStreams (applyArguments implExp)))|]
+                      | otherwise = [|Nothing <$ $(applyStreams (applyArguments implExp))|]
                     packResponse :: Q Exp -> Q Exp
                     packResponse = fmapE (conE (responseFunctionCtorName api fun))
                     applyArguments :: Q Exp -> Q Exp
@@ -250,16 +286,16 @@ makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec]
                         go [] ex = ex
                         go (n:ns) ex = go ns (appE ex (varE n))
                     applyStreams :: Q Exp -> Q Exp
-                    applyStreams = go fun.streams
+                    applyStreams = go fun.streams streamNames
                       where
-                        go :: [RpcStream] -> Q Exp -> Q Exp
-                        go [] ex = ex
-                        go (s:ss) ex = go ss (appE ex [|undefined|])
+                        go :: [RpcStream] -> [Name] -> Q Exp -> Q Exp
+                        go [] [] ex = ex
+                        go (s:ss) (sn:sns) ex = go ss sns (appE ex (varE sn))
                     implExp :: Q Exp
                     implExp = implExp' fun.fixedHandler
                       where
                         implExp' :: Maybe (Q Exp) -> Q Exp
-                        implExp' Nothing = varE (implFieldName api fun) `appE` varE implName
+                        implExp' Nothing = varE (implFieldName api fun) `appE` implE
                         implExp' (Just handler) = [|
                           let
                             impl :: $(implSig)
@@ -281,7 +317,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p)
 
 class RpcProtocol p => HasProtocolImpl p where
   type ProtocolImpl p
-  handleMessage :: ProtocolImpl p -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p))
+  handleMessage :: ProtocolImpl p -> [Channel] -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p))
 
 
 data Client p = Client {
@@ -296,17 +332,18 @@ emptyClientState = ClientState {
   callbacks = HM.empty
 }
 
-clientSend :: RpcProtocol p => Client p -> ProtocolRequest p -> IO ()
-clientSend client req = void $ channelSend_ client.channel [] (encode req)
-clientRequestBlocking :: forall p. RpcProtocol p => Client p -> ProtocolRequest p -> IO (ProtocolResponse p)
-clientRequestBlocking client req = do
+clientSend :: RpcProtocol p => Client p -> MessageConfiguration -> ProtocolRequest p -> IO SentMessageResources
+clientSend client config req = channelSend_ client.channel config (encode req)
+clientRequestBlocking :: forall p. RpcProtocol p => Client p -> MessageConfiguration -> ProtocolRequest p -> IO (ProtocolResponse p, SentMessageResources)
+clientRequestBlocking client config req = do
   resultMVar <- newEmptyMVar
-  void $ channelSend client.channel [] (encode req) $ \msgId ->
+  sentMessageResources <- channelSend client.channel config (encode req) $ \msgId ->
     modifyMVar_ client.stateMVar $
       \state -> pure state{callbacks = HM.insert msgId (requestCompletedCallback resultMVar msgId) state.callbacks}
   -- Block on resultMVar until the request completes
   -- TODO: Future-based variant
-  takeMVar resultMVar
+  result <- takeMVar resultMVar
+  pure (result, sentMessageResources)
   where
     requestCompletedCallback :: MVar (ProtocolResponse p) -> MessageId -> ProtocolResponse p -> IO ()
     requestCompletedCallback resultMVar msgId response = do
@@ -321,7 +358,7 @@ clientHandleChannelMessage client resources msg = case decodeOrFail msg of
   where
     clientHandleResponse :: ProtocolResponseWrapper p -> IO ()
     clientHandleResponse (requestId, resp) = do
-      mapM_ undefined resources.createdChannels
+      unless (null resources.createdChannels) (channelReportProtocolError client.channel "Received unexpected new channel during a rpc response")
       callback <- modifyMVar client.stateMVar $ \state -> do
         let (callbacks, mCallback) = lookupDelete requestId state.callbacks
         case mCallback of
@@ -339,11 +376,11 @@ clientReportProtocolError client = channelReportProtocolError client.channel
 serverHandleChannelMessage :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Channel -> ReceivedMessageResources -> BSL.ByteString -> IO ()
 serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFail msg of
     Left (_, _, errMsg) -> channelReportProtocolError channel errMsg
-    Right ("", _, req) -> serverHandleChannelRequest req
+    Right ("", _, req) -> serverHandleChannelRequest resources.createdChannels req
     Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers))
   where
-    serverHandleChannelRequest :: ProtocolRequest p -> IO ()
-    serverHandleChannelRequest req = handleMessage @p protocolImpl req >>= maybe (pure ()) serverSendResponse
+    serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO ()
+    serverHandleChannelRequest channels req = handleMessage @p protocolImpl channels req >>= maybe (pure ()) serverSendResponse
     serverSendResponse :: ProtocolResponse p -> IO ()
     serverSendResponse response = channelSendSimple channel (encode wrappedResponse)
       where
@@ -354,16 +391,19 @@ registerChannelServerHandler :: forall p. (RpcProtocol p, HasProtocolImpl p) =>
 registerChannelServerHandler protocolImpl channel = channelSetHandler channel (serverHandleChannelMessage @p protocolImpl channel)
 
 
-data Stream up down = Stream
+newtype Stream up down = Stream Channel
+
+newStream :: Channel -> IO (Stream up down)
+newStream = pure . Stream
 
-streamSend :: Stream up down -> up -> IO ()
-streamSend = undefined
+streamSend :: Binary up => Stream up down -> up -> IO ()
+streamSend (Stream channel) value = channelSendSimple channel (encode value)
 
-streamSetHandler :: Stream up down -> (down -> IO ()) -> IO ()
-streamSetHandler = undefined
+streamSetHandler :: Binary down => Stream up down -> (down -> IO ()) -> IO ()
+streamSetHandler (Stream channel) handler = channelSetSimpleHandler channel handler
 
 streamClose :: Stream up down -> IO ()
-streamClose = undefined
+streamClose (Stream channel) = channelClose channel
 
 -- ** Running client and server
 
@@ -552,7 +592,7 @@ varDefaultBangType  :: Name -> Q Type -> Q VarBangType
 varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpackedness noSourceStrictness) qType
 
 fmapE :: Q Exp -> Q Exp -> Q Exp
-fmapE f = appE (appE (varE 'fmap) f)
+fmapE f e = [|$(f) <$> $(e)|]
 
 -- ** General helper functions
 
-- 
GitLab