Skip to content
Snippets Groups Projects
Commit 962db62b authored by Jens Nolte's avatar Jens Nolte
Browse files

Implement extendable codegen interface

parent cfa5fe3b
No related branches found
No related tags found
No related merge requests found
Pipeline #2303 passed
......@@ -63,7 +63,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p)
class RpcProtocol p => HasProtocolImpl p where
type ProtocolImpl p
handleMessage :: ProtocolImpl p -> [Channel] -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p))
handleRequest :: ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> IO (Maybe (ProtocolResponse p))
data Client p = Client {
......@@ -126,7 +126,7 @@ serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFai
Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers))
where
serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO ()
serverHandleChannelRequest channels req = handleMessage @p protocolImpl channels req >>= maybe (pure ()) serverSendResponse
serverHandleChannelRequest channels req = handleRequest @p protocolImpl channel req channels >>= maybe (pure ()) serverSendResponse
serverSendResponse :: ProtocolResponse p -> IO ()
serverSendResponse response = channelSendSimple channel (encode wrappedResponse)
where
......
......@@ -20,7 +20,7 @@ import Control.Monad.State (State, execState)
import qualified Control.Monad.State as State
import Data.Binary (Binary)
import Data.Maybe (isNothing)
import GHC.Generics
import GHC.Records.Compat (HasField)
import Language.Haskell.TH hiding (interruptible)
import Language.Haskell.TH.Syntax
import Quasar.Network.Multiplexer
......@@ -88,10 +88,12 @@ setFixedHandler handler = State.modify (\fun -> fun{fixedHandler = Just handler}
-- | Generates rpc protocol types, rpc client and rpc server
makeRpc :: RpcApi -> Q [Dec]
makeRpc api = mconcat <$> sequence [makeProtocol api, makeClient api, makeServer api]
makeRpc api = do
code <- mconcat <$> sequence (generateFunction api <$> api.functions)
mconcat <$> sequence [makeProtocol api code, makeClient code, makeServer api code]
makeProtocol :: RpcApi -> Q [Dec]
makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec, messageDec, responseDec]
makeProtocol :: RpcApi -> Code -> Q [Dec]
makeProtocol api code = sequence [protocolDec, protocolInstanceDec, requestDec, responseDec]
where
protocolDec :: Q Dec
protocolDec = dataD (pure []) (protocolTypeName api) [] Nothing [] []
......@@ -102,34 +104,199 @@ makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec,
tySynInstD (tySynEqn Nothing (appT (conT ''ProtocolResponse) (protocolType api)) (conT (responseTypeName api)))
]
messageDec :: Q Dec
messageDec = dataD (pure []) (requestTypeName api) [] Nothing (messageCon <$> functions) serializableTypeDerivClauses
requestDec :: Q Dec
requestDec = dataD (pure []) (requestTypeName api) [] Nothing (requestCon <$> code.requests) serializableTypeDerivClauses
where
messageCon :: RpcFunction -> Q Con
messageCon fun = normalC (requestFunctionCtorName api fun) (messageConVar <$> fun.arguments)
where
messageConVar :: RpcArgument -> Q BangType
messageConVar (RpcArgument _name ty) = defaultBangType ty
requestCon :: Request -> Q Con
requestCon req = normalC (requestConName api req) (defaultBangType . (.ty) <$> req.fields)
responseDec :: Q Dec
responseDec = dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> filter hasResult functions) serializableTypeDerivClauses
responseDec = do
dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> catMaybes ((.mResponse) <$> code.requests)) serializableTypeDerivClauses
where
responseCon :: RpcFunction -> Q Con
responseCon fun = normalC (responseFunctionCtorName api fun) [defaultBangType (resultTupleType fun)]
resultTupleType :: RpcFunction -> Q Type
resultTupleType fun = buildTupleType (sequence ((.ty) <$> fun.results))
responseCon :: Response -> Q Con
responseCon resp = normalC (responseConName api resp) (defaultBangType . (.ty) <$> resp.fields)
serializableTypeDerivClauses :: [Q DerivClause]
serializableTypeDerivClauses = [
derivClause Nothing [[t|Eq|], [t|Show|], [t|Generic|], [t|Binary|]]
]
makeClient :: RpcApi -> Q [Dec]
makeClient api@RpcApi{functions} = do
mconcat <$> mapM makeClientFunction functions
makeClient :: Code -> Q [Dec]
makeClient code = sequence code.stubDecs
makeServer :: RpcApi -> Code -> Q [Dec]
makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstanceDec]
where
makeClientFunction :: RpcFunction -> Q [Dec]
makeClientFunction fun = do
protocolImplDec :: Q Dec
protocolImplDec = do
dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) code.serverImplFields] []
functionImplType :: RpcFunction -> Q Type
functionImplType fun = do
argumentTypes <- functionArgumentTypes fun
streamTypes <- serverStreamTypes
buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
where
serverStreamTypes :: Q [Type]
serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
logicInstanceDec :: Q Dec
logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [
tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)),
requestHandler
]
requestHandler :: Q Dec
requestHandler = do
requestHandlerPrimeName <- newName "handleRequest"
implRecordName <- newName "implementation"
channelName <- newName "onChannel"
funD 'handleRequest [clause [varP implRecordName, varP channelName] (normalB (varE requestHandlerPrimeName)) [funD requestHandlerPrimeName (requestHandlerClauses implRecordName channelName)]]
where
requestHandlerClauses :: Name -> Name -> [Q Clause]
requestHandlerClauses implRecordName channelName = (mconcat $ (requestClauses implRecordName channelName) <$> code.requests)
requestClauses :: Name -> Name -> Request -> [Q Clause]
requestClauses implRecordName channelName req = [mainClause, invalidChannelCountClause]
where
mainClause :: Q Clause
mainClause = do
channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (req.numPipelinedChannels - 1)]
fieldNames <- sequence $ newName . (.name) <$> req.fields
let requestConP = conP (requestConName api req) (varP <$> fieldNames)
ctx = RequestHandlerContext {
implRecordE = varE implRecordName,
argumentEs = (varE <$> fieldNames),
channelEs = (varE <$> channelNames)
}
clause
[requestConP, listP (varP <$> channelNames)]
(normalB (packResponse req.mResponse (req.handlerE ctx)))
[]
invalidChannelCountClause :: Q Clause
invalidChannelCountClause = do
channelsName <- newName "newChannels"
let requestConP = conP (requestConName api req) (replicate (length req.fields) wildP)
clause
[requestConP, varP channelsName]
(normalB [|$(varE 'reportInvalidChannelCount) $(litE (integerL (toInteger req.numPipelinedChannels))) $(varE channelsName) $(varE channelName)|])
[]
packResponse :: Maybe Response -> Q Exp -> Q Exp
packResponse Nothing handlerE = [|Nothing <$ $(handlerE)|]
packResponse (Just response) handlerE = [|Just . $(conE (responseConName api response)) <$> $handlerE|]
-- * Pluggable codegen interface
data Code = Code {
stubDecs :: [Q Dec],
serverImplFields :: [Q VarBangType],
requests :: [Request]
}
instance Semigroup Code where
x <> y = Code {
stubDecs = x.stubDecs <> y.stubDecs,
serverImplFields = x.serverImplFields <> y.serverImplFields,
requests = x.requests <> y.requests
}
instance Monoid Code where
mempty = Code {
stubDecs = [],
serverImplFields = [],
requests = []
}
data Request = Request {
name :: String,
fields :: [Field],
numPipelinedChannels :: Int,
mResponse :: Maybe Response,
handlerE :: RequestHandlerContext -> Q Exp
}
data Response = Response {
name :: String,
fields :: [Field]
--numCreatedChannels :: Int
}
data Field = Field {
name :: String,
ty :: Q Type
}
toField :: (HasField "name" a String, HasField "ty" a (Q Type)) => a -> Field
toField x = Field { name = x.name, ty = x.ty }
data RequestHandlerContext = RequestHandlerContext {
implRecordE :: Q Exp,
argumentEs :: [Q Exp],
channelEs :: [Q Exp]
}
-- * Rpc function code generator
generateFunction :: RpcApi -> RpcFunction -> Q Code
generateFunction api fun = do
stubDecs <- clientFunctionStub
pure Code {
stubDecs,
serverImplFields =
if isNothing fun.fixedHandler
then [ varDefaultBangType implFieldName implSig ]
else [],
requests = [request]
}
where
request :: Request
request = Request {
name = fun.name,
fields = toField <$> fun.arguments,
numPipelinedChannels = length fun.streams,
mResponse = if hasResult fun then Just response else Nothing,
handlerE = serverRequestHandlerE
}
response :: Response
response = Response {
name = fun.name,
-- TODO unpack?
fields = [ Field { name = "packedResponse", ty = buildTupleType (sequence ((.ty) <$> fun.results)) } ]
--numCreatedChannels = undefined
}
implFieldName :: Name
implFieldName = functionImplFieldName api fun
implSig :: Q Type
implSig = do
argumentTypes <- functionArgumentTypes fun
streamTypes <- serverStreamTypes
buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
where
serverStreamTypes :: Q [Type]
serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
serverRequestHandlerE :: RequestHandlerContext -> Q Exp
serverRequestHandlerE ctx = applyChannels ctx.channelEs (applyArgs (implFieldE ctx.implRecordE))
where
implFieldE :: Q Exp -> Q Exp
implFieldE implRecordE = case fun.fixedHandler of
Nothing -> [|$(varE implFieldName) $implRecordE|]
Just handler -> [|$(handler) :: $implSig|]
applyArgs :: Q Exp -> Q Exp
applyArgs implE = foldl appE implE ctx.argumentEs
applyChannels :: [Q Exp] -> Q Exp -> Q Exp
applyChannels [] implE = implE
applyChannels (channel0E:channelEs) implE = varE 'join `appE` foldl
(\x y -> [|$x <*> $y|])
([|$implE <$> $(createStream channel0E)|])
(createStream <$> channelEs)
where
createStream :: Q Exp -> Q Exp
createStream = (varE 'newStream `appE`)
clientFunctionStub :: Q [Q Dec]
clientFunctionStub = do
clientVarName <- newName "client"
argNames <- sequence (newName . (.name) <$> fun.arguments)
channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
......@@ -138,13 +305,13 @@ makeClient api@RpcApi{functions} = do
where
funName :: Name
funName = mkName fun.name
makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Dec]
makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Q Dec]
makeClientFunction' clientVarName argNames channelNames streamNames = do
funArgTypes <- functionArgumentTypes fun
clientType <- [t|Client $(protocolType api)|]
resultType <- optionalResultType
streamTypes <- clientStreamTypes
sequence [
pure [
sigD funName (buildFunctionType (pure ([clientType] <> funArgTypes)) [t|IO $(buildTupleType (pure (resultType <> streamTypes)))|]),
funD funName [clause ([varP clientVarName] <> varPats) body []]
]
......@@ -161,10 +328,12 @@ makeClient api@RpcApi{functions} = do
varPats = varP <$> argNames
body :: Q Body
body
| hasResult fun = normalB $ doE $
| hasResult fun = do
responseName <- newName "response"
normalB $ doE $
[
bindS [p|(response, resources)|] (requestE requestDataE),
bindS [p|result|] (checkResult [|response|])
bindS [p|($(varP responseName), resources)|] (requestE requestDataE),
bindS [p|result|] (checkResult (varE responseName))
] <>
createStreams [|resources.createdChannels|] <>
[noBindS [|pure $(buildTuple (liftA2 (:) [|result|] streamsE))|]]
......@@ -173,7 +342,7 @@ makeClient api@RpcApi{functions} = do
createStreams [|resources.createdChannels|] <>
[noBindS [|pure $(buildTuple streamsE)|]]
requestDataE :: Q Exp
requestDataE = applyVars (conE (requestFunctionCtorName api fun))
requestDataE = applyVars (conE (requestFunctionConName api fun))
createStreams :: Q Exp -> [Q Stmt]
createStreams channelsE = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels]
where
......@@ -223,111 +392,6 @@ makeClient api@RpcApi{functions} = do
typedRequest :: Q Exp
typedRequest = appTypeE (varE 'clientRequestBlocking) (protocolType api)
makeServer :: RpcApi -> Q [Dec]
makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec]
where
handlerRecordDec :: Q Dec
handlerRecordDec = dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) (handlerRecordField <$> functionsWithoutBuiltinHandler)] []
functionsWithoutBuiltinHandler :: [RpcFunction]
functionsWithoutBuiltinHandler = filter (isNothing . fixedHandler) functions
handlerRecordField :: RpcFunction -> Q VarBangType
handlerRecordField fun = varDefaultBangType (implFieldName api fun) (handlerFunctionType fun)
handlerFunctionType :: RpcFunction -> Q Type
handlerFunctionType fun = do
argumentTypes <- functionArgumentTypes fun
streamTypes <- serverStreamTypes
buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|]
where
serverStreamTypes :: Q [Type]
serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams
logicInstanceDec :: Q Dec
logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [
tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)),
messageHandler
]
messageHandler :: Q Dec
messageHandler = do
handleMessagePrimeName <- newName "handleMessage"
implName <- newName "impl"
channelsName <- newName "channels"
funD 'handleMessage [clause [varP implName, varP channelsName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName (varE implName) (varE channelsName)]]
where
handleMessagePrimeDec :: Name -> Q Exp -> Q Exp -> Q Dec
handleMessagePrimeDec handleMessagePrimeName implE channelsE = funD handleMessagePrimeName (handlerFunctionClause <$> functions)
where
handlerFunctionClause :: RpcFunction -> Q Clause
handlerFunctionClause fun = do
argNames <- sequence (newName . (.name) <$> fun.arguments)
channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams)
streamNames <- sequence (newName . (.name) <$> fun.streams)
serverLogicHandlerFunctionClause' argNames channelNames streamNames
where
serverLogicHandlerFunctionClause' :: [Name] -> [Name] -> [Name] -> Q Clause
serverLogicHandlerFunctionClause' argNames channelNames streamNames = clause [conP (requestFunctionCtorName api fun) varPats] body []
where
varPats :: [Q Pat]
varPats = varP <$> argNames
body :: Q Body
body = normalB $ doE $ createStreams <> [callImplementation]
createStreams :: [Q Stmt]
createStreams = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels]
where
verifyNoChannels :: Q Stmt
verifyNoChannels = noBindS [|unless (null $(channelsE)) (fail "Received invalid channel count")|] -- TODO channelReportProtocolError
assignChannels :: Q Stmt
assignChannels =
bindS
(tupP (varP <$> channelNames))
$ caseE channelsE [
match (listP (varP <$> channelNames)) (normalB [|pure $(tupE (varE <$> channelNames))|]) [],
match [p|_|] (normalB [|fail "Received invalid channel count"|]) [] -- TODO channelReportProtocolError
]
go :: [Name] -> [Name] -> [Q Stmt]
go [] [] = []
go (cn:cns) (sn:sns) = createStream cn sn : go cns sns
go _ _ = fail "Logic error: lists have different lengths"
createStream :: Name -> Name -> Q Stmt
createStream channelName streamName = bindS (varP streamName) [|$(varE 'newStream) $(varE channelName)|]
callImplementation :: Q Stmt
callImplementation = noBindS callImplementationE
callImplementationE :: Q Exp
callImplementationE
| hasResult fun = [|Just <$> $(packResponse (applyStreams (applyArguments implExp)))|]
| otherwise = [|Nothing <$ $(applyStreams (applyArguments implExp))|]
packResponse :: Q Exp -> Q Exp
packResponse = fmapE (conE (responseFunctionCtorName api fun))
applyArguments :: Q Exp -> Q Exp
applyArguments = go argNames
where
go :: [Name] -> Q Exp -> Q Exp
go [] ex = ex
go (n:ns) ex = go ns (appE ex (varE n))
applyStreams :: Q Exp -> Q Exp
applyStreams = go streamNames
where
go :: [Name] -> Q Exp -> Q Exp
go [] ex = ex
go (sn:sns) ex = go sns (appE ex (varE sn))
implExp :: Q Exp
implExp = implExp' fun.fixedHandler
where
implExp' :: Maybe (Q Exp) -> Q Exp
implExp' Nothing = varE (implFieldName api fun) `appE` implE
implExp' (Just handler) = [|
let
impl :: $(implSig)
impl = $(handler)
in impl
|]
implSig :: Q Type
implSig = handlerFunctionType fun
-- * Internal
-- ** Protocol generator helpers
functionArgumentTypes :: RpcFunction -> Q [Type]
functionArgumentTypes fun = sequence $ (.ty) <$> fun.arguments
functionResultTypes :: RpcFunction -> Q [Type]
......@@ -337,7 +401,7 @@ hasResult :: RpcFunction -> Bool
hasResult fun = not (null fun.results)
-- *** Name helper functions
-- ** Name helper functions
protocolTypeName :: RpcApi -> Name
protocolTypeName RpcApi{name} = mkName (name <> "Protocol")
......@@ -351,8 +415,11 @@ requestTypeIdentifier RpcApi{name} = name <> "ProtocolRequest"
requestTypeName :: RpcApi -> Name
requestTypeName = mkName . requestTypeIdentifier
requestFunctionCtorName :: RpcApi -> RpcFunction -> Name
requestFunctionCtorName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name)
requestFunctionConName :: RpcApi -> RpcFunction -> Name
requestFunctionConName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name)
requestConName :: RpcApi -> Request -> Name
requestConName api req = mkName (requestTypeIdentifier api <> "_" <> req.name)
responseTypeIdentifier :: RpcApi -> String
responseTypeIdentifier RpcApi{name} = name <> "ProtocolResponse"
......@@ -363,16 +430,19 @@ responseTypeName = mkName . responseTypeIdentifier
responseFunctionCtorName :: RpcApi -> RpcFunction -> Name
responseFunctionCtorName api fun = mkName (responseTypeIdentifier api <> "_" <> fun.name)
responseConName :: RpcApi -> Response -> Name
responseConName api resp = mkName (responseTypeIdentifier api <> "_" <> resp.name)
implTypeName :: RpcApi -> Name
implTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl"
implType :: RpcApi -> Q Type
implType = conT . implTypeName
implFieldName :: RpcApi -> RpcFunction -> Name
implFieldName _api fun = mkName (fun.name <> "Impl")
functionImplFieldName :: RpcApi -> RpcFunction -> Name
functionImplFieldName _api fun = mkName (fun.name <> "Impl")
-- ** Template Haskell helper functions
-- * Template Haskell helper functions
funT :: Q Type -> Q Type -> Q Type
funT x = appT (appT arrowT x)
......@@ -412,3 +482,11 @@ varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpack
fmapE :: Q Exp -> Q Exp -> Q Exp
fmapE f e = [|$(f) <$> $(e)|]
-- * Error reporting
reportInvalidChannelCount :: Int -> [Channel] -> Channel -> IO a
reportInvalidChannelCount expectedCount newChannels onChannel = channelReportProtocolError onChannel msg
where
msg = mconcat parts
parts = ["Received ", show (length newChannels), " new channels, but expected ", show expectedCount]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment