From 962db62b610591e4eea6fb6ef13b816252370773 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 13 Jul 2021 18:09:26 +0200
Subject: [PATCH] Implement extendable codegen interface

---
 src/Quasar/Network/Runtime.hs |   4 +-
 src/Quasar/Network/TH.hs      | 354 +++++++++++++++++++++-------------
 2 files changed, 218 insertions(+), 140 deletions(-)

diff --git a/src/Quasar/Network/Runtime.hs b/src/Quasar/Network/Runtime.hs
index 5356d3c..036b3b0 100644
--- a/src/Quasar/Network/Runtime.hs
+++ b/src/Quasar/Network/Runtime.hs
@@ -63,7 +63,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p)
 
 class RpcProtocol p => HasProtocolImpl p where
   type ProtocolImpl p
-  handleMessage :: ProtocolImpl p -> [Channel] -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p))
+  handleRequest :: ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> IO (Maybe (ProtocolResponse p))
 
 
 data Client p = Client {
@@ -126,7 +126,7 @@ serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFai
     Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers))
   where
     serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO ()
-    serverHandleChannelRequest channels req = handleMessage @p protocolImpl channels req >>= maybe (pure ()) serverSendResponse
+    serverHandleChannelRequest channels req = handleRequest @p protocolImpl channel req channels >>= maybe (pure ()) serverSendResponse
     serverSendResponse :: ProtocolResponse p -> IO ()
     serverSendResponse response = channelSendSimple channel (encode wrappedResponse)
       where
diff --git a/src/Quasar/Network/TH.hs b/src/Quasar/Network/TH.hs
index a8def53..7a01015 100644
--- a/src/Quasar/Network/TH.hs
+++ b/src/Quasar/Network/TH.hs
@@ -20,7 +20,7 @@ import Control.Monad.State (State, execState)
 import qualified Control.Monad.State as State
 import Data.Binary (Binary)
 import Data.Maybe (isNothing)
-import GHC.Generics
+import GHC.Records.Compat (HasField)
 import Language.Haskell.TH hiding (interruptible)
 import Language.Haskell.TH.Syntax
 import Quasar.Network.Multiplexer
@@ -88,10 +88,12 @@ setFixedHandler handler = State.modify (\fun -> fun{fixedHandler = Just handler}
 
 -- | Generates rpc protocol types, rpc client and rpc server
 makeRpc :: RpcApi -> Q [Dec]
-makeRpc api = mconcat <$> sequence [makeProtocol api, makeClient api, makeServer api]
+makeRpc api = do
+  code <- mconcat <$> sequence (generateFunction api <$> api.functions)
+  mconcat <$> sequence [makeProtocol api code, makeClient code, makeServer api code]
 
-makeProtocol :: RpcApi -> Q [Dec]
-makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec, messageDec, responseDec]
+makeProtocol :: RpcApi -> Code -> Q [Dec]
+makeProtocol api code = sequence [protocolDec, protocolInstanceDec, requestDec, responseDec]
   where
     protocolDec :: Q Dec
     protocolDec = dataD (pure []) (protocolTypeName api) [] Nothing [] []
@@ -102,34 +104,199 @@ makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec,
       tySynInstD (tySynEqn Nothing (appT (conT ''ProtocolResponse) (protocolType api)) (conT (responseTypeName api)))
       ]
 
-    messageDec :: Q Dec
-    messageDec = dataD (pure []) (requestTypeName api) [] Nothing (messageCon <$> functions) serializableTypeDerivClauses
+    requestDec :: Q Dec
+    requestDec = dataD (pure []) (requestTypeName api) [] Nothing (requestCon <$> code.requests) serializableTypeDerivClauses
       where
-        messageCon :: RpcFunction -> Q Con
-        messageCon fun = normalC (requestFunctionCtorName api fun) (messageConVar <$> fun.arguments)
-          where
-            messageConVar :: RpcArgument -> Q BangType
-            messageConVar (RpcArgument _name ty) = defaultBangType ty
+        requestCon :: Request -> Q Con
+        requestCon req = normalC (requestConName api req) (defaultBangType . (.ty) <$> req.fields)
 
     responseDec :: Q Dec
-    responseDec = dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> filter hasResult functions) serializableTypeDerivClauses
+    responseDec = do
+      dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> catMaybes ((.mResponse) <$> code.requests)) serializableTypeDerivClauses
       where
-        responseCon :: RpcFunction -> Q Con
-        responseCon fun = normalC (responseFunctionCtorName api fun) [defaultBangType (resultTupleType fun)]
-        resultTupleType :: RpcFunction -> Q Type
-        resultTupleType fun = buildTupleType (sequence ((.ty) <$> fun.results))
+        responseCon :: Response -> Q Con
+        responseCon resp = normalC (responseConName api resp) (defaultBangType . (.ty) <$> resp.fields)
 
     serializableTypeDerivClauses :: [Q DerivClause]
     serializableTypeDerivClauses = [
       derivClause Nothing [[t|Eq|], [t|Show|], [t|Generic|], [t|Binary|]]
       ]
 
-makeClient :: RpcApi -> Q [Dec]
-makeClient api@RpcApi{functions} = do
-  mconcat <$> mapM makeClientFunction functions
+makeClient :: Code -> Q [Dec]
+makeClient code = sequence code.stubDecs
+
+makeServer :: RpcApi -> Code -> Q [Dec]
+makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstanceDec]
   where
-    makeClientFunction :: RpcFunction -> Q [Dec]
-    makeClientFunction fun = do
+    protocolImplDec :: Q Dec
+    protocolImplDec = do
+      dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) code.serverImplFields] []
+    functionImplType :: RpcFunction -> Q Type
+    functionImplType fun = do
+      argumentTypes <- functionArgumentTypes fun
+      streamTypes <- serverStreamTypes
+      buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
+      where
+        serverStreamTypes :: Q [Type]
+        serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
+
+    logicInstanceDec :: Q Dec
+    logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [
+      tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)),
+      requestHandler
+      ]
+    requestHandler :: Q Dec
+    requestHandler = do
+      requestHandlerPrimeName <- newName "handleRequest"
+      implRecordName <- newName "implementation"
+      channelName <- newName "onChannel"
+      funD 'handleRequest [clause [varP implRecordName, varP channelName] (normalB (varE requestHandlerPrimeName)) [funD requestHandlerPrimeName (requestHandlerClauses implRecordName channelName)]]
+      where
+        requestHandlerClauses :: Name -> Name -> [Q Clause]
+        requestHandlerClauses implRecordName channelName = (mconcat $ (requestClauses implRecordName channelName) <$> code.requests)
+        requestClauses :: Name -> Name -> Request -> [Q Clause]
+        requestClauses implRecordName channelName req = [mainClause, invalidChannelCountClause]
+          where
+            mainClause :: Q Clause
+            mainClause = do
+              channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (req.numPipelinedChannels - 1)]
+
+              fieldNames <- sequence $ newName . (.name) <$> req.fields
+              let requestConP = conP (requestConName api req) (varP <$> fieldNames)
+                  ctx = RequestHandlerContext {
+                    implRecordE = varE implRecordName,
+                    argumentEs = (varE <$> fieldNames),
+                    channelEs = (varE <$> channelNames)
+                  }
+
+              clause
+                [requestConP, listP (varP <$> channelNames)]
+                (normalB (packResponse req.mResponse (req.handlerE ctx)))
+                []
+
+            invalidChannelCountClause :: Q Clause
+            invalidChannelCountClause = do
+              channelsName <- newName "newChannels"
+              let requestConP = conP (requestConName api req) (replicate (length req.fields) wildP)
+              clause
+                [requestConP, varP channelsName]
+                (normalB [|$(varE 'reportInvalidChannelCount) $(litE (integerL (toInteger req.numPipelinedChannels))) $(varE channelsName) $(varE channelName)|])
+                []
+
+    packResponse :: Maybe Response -> Q Exp -> Q Exp
+    packResponse Nothing handlerE = [|Nothing <$ $(handlerE)|]
+    packResponse (Just response) handlerE = [|Just . $(conE (responseConName api response)) <$> $handlerE|]
+
+
+-- * Pluggable codegen interface
+
+data Code = Code {
+  stubDecs :: [Q Dec],
+  serverImplFields :: [Q VarBangType],
+  requests :: [Request]
+}
+instance Semigroup Code where
+  x <> y = Code {
+    stubDecs = x.stubDecs <> y.stubDecs,
+    serverImplFields = x.serverImplFields <> y.serverImplFields,
+    requests = x.requests <> y.requests
+  }
+instance Monoid Code where
+  mempty = Code {
+    stubDecs = [],
+    serverImplFields = [],
+    requests = []
+  }
+
+data Request = Request {
+  name :: String,
+  fields :: [Field],
+  numPipelinedChannels :: Int,
+  mResponse :: Maybe Response,
+  handlerE :: RequestHandlerContext -> Q Exp
+}
+
+data Response = Response {
+  name :: String,
+  fields :: [Field]
+  --numCreatedChannels :: Int
+}
+
+data Field = Field {
+  name :: String,
+  ty :: Q Type
+}
+toField :: (HasField "name" a String, HasField "ty" a (Q Type)) => a -> Field
+toField x = Field { name = x.name, ty = x.ty }
+
+data RequestHandlerContext = RequestHandlerContext {
+  implRecordE :: Q Exp,
+  argumentEs :: [Q Exp],
+  channelEs :: [Q Exp]
+}
+
+
+-- * Rpc function code generator
+
+generateFunction :: RpcApi -> RpcFunction -> Q Code
+generateFunction api fun = do
+  stubDecs <- clientFunctionStub
+  pure Code {
+    stubDecs,
+    serverImplFields =
+      if isNothing fun.fixedHandler
+        then [ varDefaultBangType implFieldName implSig ]
+        else [],
+    requests = [request]
+  }
+  where
+    request :: Request
+    request = Request {
+      name = fun.name,
+      fields = toField <$> fun.arguments,
+      numPipelinedChannels = length fun.streams,
+      mResponse = if hasResult fun then Just response else Nothing,
+      handlerE = serverRequestHandlerE
+    }
+    response :: Response
+    response = Response {
+      name = fun.name,
+      -- TODO unpack?
+      fields = [ Field { name = "packedResponse", ty = buildTupleType (sequence ((.ty) <$> fun.results)) } ]
+      --numCreatedChannels = undefined
+    }
+    implFieldName :: Name
+    implFieldName = functionImplFieldName api fun
+    implSig :: Q Type
+    implSig = do
+      argumentTypes <- functionArgumentTypes fun
+      streamTypes <- serverStreamTypes
+      buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
+      where
+        serverStreamTypes :: Q [Type]
+        serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
+
+    serverRequestHandlerE :: RequestHandlerContext -> Q Exp
+    serverRequestHandlerE ctx = applyChannels ctx.channelEs (applyArgs (implFieldE ctx.implRecordE))
+      where
+        implFieldE :: Q Exp -> Q Exp
+        implFieldE implRecordE = case fun.fixedHandler of
+          Nothing -> [|$(varE implFieldName) $implRecordE|]
+          Just handler -> [|$(handler) :: $implSig|]
+        applyArgs :: Q Exp -> Q Exp
+        applyArgs implE = foldl appE implE ctx.argumentEs
+        applyChannels :: [Q Exp] -> Q Exp -> Q Exp
+        applyChannels [] implE = implE
+        applyChannels (channel0E:channelEs) implE = varE 'join `appE` foldl
+          (\x y -> [|$x <*> $y|])
+          ([|$implE <$> $(createStream channel0E)|])
+          (createStream <$> channelEs)
+          where
+            createStream :: Q Exp -> Q Exp
+            createStream = (varE 'newStream `appE`)
+
+    clientFunctionStub :: Q [Q Dec]
+    clientFunctionStub = do
       clientVarName <- newName "client"
       argNames <- sequence (newName . (.name) <$> fun.arguments)
       channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
@@ -138,13 +305,13 @@ makeClient api@RpcApi{functions} = do
       where
         funName :: Name
         funName = mkName fun.name
-        makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Dec]
+        makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Q Dec]
         makeClientFunction' clientVarName argNames channelNames streamNames = do
           funArgTypes <- functionArgumentTypes fun
           clientType <- [t|Client $(protocolType api)|]
           resultType <- optionalResultType
           streamTypes <- clientStreamTypes
-          sequence [
+          pure [
             sigD funName (buildFunctionType (pure ([clientType] <> funArgTypes)) [t|IO $(buildTupleType (pure (resultType <> streamTypes)))|]),
             funD funName [clause ([varP clientVarName] <> varPats) body []]
             ]
@@ -161,10 +328,12 @@ makeClient api@RpcApi{functions} = do
             varPats = varP <$> argNames
             body :: Q Body
             body
-              | hasResult fun = normalB $ doE $
+              | hasResult fun = do
+                responseName <- newName "response"
+                normalB $ doE $
                   [
-                    bindS [p|(response, resources)|] (requestE requestDataE),
-                    bindS [p|result|] (checkResult [|response|])
+                    bindS [p|($(varP responseName), resources)|] (requestE requestDataE),
+                    bindS [p|result|] (checkResult (varE responseName))
                   ] <>
                   createStreams [|resources.createdChannels|] <>
                   [noBindS [|pure $(buildTuple (liftA2 (:) [|result|] streamsE))|]]
@@ -173,7 +342,7 @@ makeClient api@RpcApi{functions} = do
                   createStreams [|resources.createdChannels|] <>
                   [noBindS [|pure $(buildTuple streamsE)|]]
             requestDataE :: Q Exp
-            requestDataE = applyVars (conE (requestFunctionCtorName api fun))
+            requestDataE = applyVars (conE (requestFunctionConName api fun))
             createStreams :: Q Exp -> [Q Stmt]
             createStreams channelsE = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels]
               where
@@ -223,111 +392,6 @@ makeClient api@RpcApi{functions} = do
             typedRequest :: Q Exp
             typedRequest = appTypeE (varE 'clientRequestBlocking) (protocolType api)
 
-
-makeServer :: RpcApi -> Q [Dec]
-makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec]
-  where
-    handlerRecordDec :: Q Dec
-    handlerRecordDec = dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) (handlerRecordField <$> functionsWithoutBuiltinHandler)] []
-    functionsWithoutBuiltinHandler :: [RpcFunction]
-    functionsWithoutBuiltinHandler = filter (isNothing . fixedHandler) functions
-    handlerRecordField :: RpcFunction -> Q VarBangType
-    handlerRecordField fun = varDefaultBangType (implFieldName api fun) (handlerFunctionType fun)
-    handlerFunctionType :: RpcFunction -> Q Type
-    handlerFunctionType fun = do
-      argumentTypes <- functionArgumentTypes fun
-      streamTypes <- serverStreamTypes
-      buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
-      where
-        serverStreamTypes :: Q [Type]
-        serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
-
-    logicInstanceDec :: Q Dec
-    logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [
-      tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)),
-      messageHandler
-      ]
-    messageHandler :: Q Dec
-    messageHandler = do
-      handleMessagePrimeName <- newName "handleMessage"
-      implName <- newName "impl"
-      channelsName <- newName "channels"
-      funD 'handleMessage [clause [varP implName, varP channelsName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName (varE implName) (varE channelsName)]]
-      where
-        handleMessagePrimeDec :: Name -> Q Exp -> Q Exp -> Q Dec
-        handleMessagePrimeDec handleMessagePrimeName implE channelsE = funD handleMessagePrimeName (handlerFunctionClause <$> functions)
-          where
-            handlerFunctionClause :: RpcFunction -> Q Clause
-            handlerFunctionClause fun = do
-              argNames <- sequence (newName . (.name) <$> fun.arguments)
-              channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
-              streamNames <- sequence (newName . (.name) <$> fun.streams)
-              serverLogicHandlerFunctionClause' argNames channelNames streamNames
-              where
-                serverLogicHandlerFunctionClause' :: [Name] -> [Name] -> [Name] -> Q Clause
-                serverLogicHandlerFunctionClause' argNames channelNames streamNames = clause [conP (requestFunctionCtorName api fun) varPats] body []
-                  where
-                    varPats :: [Q Pat]
-                    varPats = varP <$> argNames
-                    body :: Q Body
-                    body = normalB $ doE $ createStreams <> [callImplementation]
-                    createStreams :: [Q Stmt]
-                    createStreams = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels]
-                      where
-                        verifyNoChannels :: Q Stmt
-                        verifyNoChannels = noBindS [|unless (null $(channelsE)) (fail "Received invalid channel count")|] -- TODO channelReportProtocolError
-                        assignChannels :: Q Stmt
-                        assignChannels =
-                          bindS
-                            (tupP (varP <$> channelNames))
-                            $ caseE channelsE [
-                              match (listP (varP <$> channelNames)) (normalB [|pure $(tupE (varE <$> channelNames))|]) [],
-                              match [p|_|] (normalB [|fail "Received invalid channel count"|]) [] -- TODO channelReportProtocolError
-                              ]
-                        go :: [Name] -> [Name] -> [Q Stmt]
-                        go [] [] = []
-                        go (cn:cns) (sn:sns) = createStream cn sn : go cns sns
-                        go _ _ = fail "Logic error: lists have different lengths"
-                        createStream :: Name -> Name -> Q Stmt
-                        createStream channelName streamName = bindS (varP streamName) [|$(varE 'newStream) $(varE channelName)|]
-                    callImplementation :: Q Stmt
-                    callImplementation = noBindS callImplementationE
-                    callImplementationE :: Q Exp
-                    callImplementationE
-                      | hasResult fun = [|Just <$> $(packResponse (applyStreams (applyArguments implExp)))|]
-                      | otherwise = [|Nothing <$ $(applyStreams (applyArguments implExp))|]
-                    packResponse :: Q Exp -> Q Exp
-                    packResponse = fmapE (conE (responseFunctionCtorName api fun))
-                    applyArguments :: Q Exp -> Q Exp
-                    applyArguments = go argNames
-                      where
-                        go :: [Name] -> Q Exp -> Q Exp
-                        go [] ex = ex
-                        go (n:ns) ex = go ns (appE ex (varE n))
-                    applyStreams :: Q Exp -> Q Exp
-                    applyStreams = go streamNames
-                      where
-                        go :: [Name] -> Q Exp -> Q Exp
-                        go [] ex = ex
-                        go (sn:sns) ex = go sns (appE ex (varE sn))
-                    implExp :: Q Exp
-                    implExp = implExp' fun.fixedHandler
-                      where
-                        implExp' :: Maybe (Q Exp) -> Q Exp
-                        implExp' Nothing = varE (implFieldName api fun) `appE` implE
-                        implExp' (Just handler) = [|
-                          let
-                            impl :: $(implSig)
-                            impl = $(handler)
-                          in impl
-                          |]
-                    implSig :: Q Type
-                    implSig = handlerFunctionType fun
-
--- * Internal
-
--- ** Protocol generator helpers
-
 functionArgumentTypes :: RpcFunction -> Q [Type]
 functionArgumentTypes fun = sequence $ (.ty) <$> fun.arguments
 functionResultTypes :: RpcFunction -> Q [Type]
@@ -337,7 +401,7 @@ hasResult :: RpcFunction -> Bool
 hasResult fun = not (null fun.results)
 
 
--- *** Name helper functions
+-- ** Name helper functions
 
 protocolTypeName :: RpcApi -> Name
 protocolTypeName RpcApi{name} = mkName (name <> "Protocol")
@@ -351,8 +415,11 @@ requestTypeIdentifier RpcApi{name} = name <> "ProtocolRequest"
 requestTypeName :: RpcApi -> Name
 requestTypeName = mkName . requestTypeIdentifier
 
-requestFunctionCtorName :: RpcApi -> RpcFunction -> Name
-requestFunctionCtorName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name)
+requestFunctionConName :: RpcApi -> RpcFunction -> Name
+requestFunctionConName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name)
+
+requestConName :: RpcApi -> Request -> Name
+requestConName api req = mkName (requestTypeIdentifier api <> "_" <> req.name)
 
 responseTypeIdentifier :: RpcApi -> String
 responseTypeIdentifier RpcApi{name} = name <> "ProtocolResponse"
@@ -363,16 +430,19 @@ responseTypeName = mkName . responseTypeIdentifier
 responseFunctionCtorName :: RpcApi -> RpcFunction -> Name
 responseFunctionCtorName api fun = mkName (responseTypeIdentifier api <> "_" <> fun.name)
 
+responseConName :: RpcApi -> Response -> Name
+responseConName api resp = mkName (responseTypeIdentifier api <> "_" <> resp.name)
+
 implTypeName :: RpcApi -> Name
 implTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl"
 
 implType :: RpcApi -> Q Type
 implType = conT . implTypeName
 
-implFieldName :: RpcApi -> RpcFunction -> Name
-implFieldName _api fun = mkName (fun.name <> "Impl")
+functionImplFieldName :: RpcApi -> RpcFunction -> Name
+functionImplFieldName _api fun = mkName (fun.name <> "Impl")
 
--- ** Template Haskell helper functions
+-- * Template Haskell helper functions
 
 funT :: Q Type -> Q Type -> Q Type
 funT x = appT (appT arrowT x)
@@ -412,3 +482,11 @@ varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpack
 
 fmapE :: Q Exp -> Q Exp -> Q Exp
 fmapE f e = [|$(f) <$> $(e)|]
+
+-- * Error reporting
+
+reportInvalidChannelCount :: Int -> [Channel] -> Channel -> IO a
+reportInvalidChannelCount expectedCount newChannels onChannel = channelReportProtocolError onChannel msg
+  where
+    msg = mconcat parts
+    parts = ["Received ", show (length newChannels), " new channels, but expected ", show expectedCount]
-- 
GitLab