diff --git a/src/Quasar/Network/Runtime.hs b/src/Quasar/Network/Runtime.hs index 5356d3c3df572daba7898ca08a7390257fe951cf..036b3b0a6c7ce3e53cd06e01b7ca494f8fb3a21c 100644 --- a/src/Quasar/Network/Runtime.hs +++ b/src/Quasar/Network/Runtime.hs @@ -63,7 +63,7 @@ type ProtocolResponseWrapper p = (MessageId, ProtocolResponse p) class RpcProtocol p => HasProtocolImpl p where type ProtocolImpl p - handleMessage :: ProtocolImpl p -> [Channel] -> ProtocolRequest p -> IO (Maybe (ProtocolResponse p)) + handleRequest :: ProtocolImpl p -> Channel -> ProtocolRequest p -> [Channel] -> IO (Maybe (ProtocolResponse p)) data Client p = Client { @@ -126,7 +126,7 @@ serverHandleChannelMessage protocolImpl channel resources msg = case decodeOrFai Right (leftovers, _, _) -> channelReportProtocolError channel ("Request parser pureed unexpected leftovers: " <> show (BSL.length leftovers)) where serverHandleChannelRequest :: [Channel] -> ProtocolRequest p -> IO () - serverHandleChannelRequest channels req = handleMessage @p protocolImpl channels req >>= maybe (pure ()) serverSendResponse + serverHandleChannelRequest channels req = handleRequest @p protocolImpl channel req channels >>= maybe (pure ()) serverSendResponse serverSendResponse :: ProtocolResponse p -> IO () serverSendResponse response = channelSendSimple channel (encode wrappedResponse) where diff --git a/src/Quasar/Network/TH.hs b/src/Quasar/Network/TH.hs index a8def53724949d71b235215798c8127dc3ef3a45..7a0101548f0e0a94519b053412859cb6b3d82d45 100644 --- a/src/Quasar/Network/TH.hs +++ b/src/Quasar/Network/TH.hs @@ -20,7 +20,7 @@ import Control.Monad.State (State, execState) import qualified Control.Monad.State as State import Data.Binary (Binary) import Data.Maybe (isNothing) -import GHC.Generics +import GHC.Records.Compat (HasField) import Language.Haskell.TH hiding (interruptible) import Language.Haskell.TH.Syntax import Quasar.Network.Multiplexer @@ -88,10 +88,12 @@ setFixedHandler handler = State.modify (\fun -> fun{fixedHandler = Just handler} -- | Generates rpc protocol types, rpc client and rpc server makeRpc :: RpcApi -> Q [Dec] -makeRpc api = mconcat <$> sequence [makeProtocol api, makeClient api, makeServer api] +makeRpc api = do + code <- mconcat <$> sequence (generateFunction api <$> api.functions) + mconcat <$> sequence [makeProtocol api code, makeClient code, makeServer api code] -makeProtocol :: RpcApi -> Q [Dec] -makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec, messageDec, responseDec] +makeProtocol :: RpcApi -> Code -> Q [Dec] +makeProtocol api code = sequence [protocolDec, protocolInstanceDec, requestDec, responseDec] where protocolDec :: Q Dec protocolDec = dataD (pure []) (protocolTypeName api) [] Nothing [] [] @@ -102,34 +104,199 @@ makeProtocol api@RpcApi{functions} = sequence [protocolDec, protocolInstanceDec, tySynInstD (tySynEqn Nothing (appT (conT ''ProtocolResponse) (protocolType api)) (conT (responseTypeName api))) ] - messageDec :: Q Dec - messageDec = dataD (pure []) (requestTypeName api) [] Nothing (messageCon <$> functions) serializableTypeDerivClauses + requestDec :: Q Dec + requestDec = dataD (pure []) (requestTypeName api) [] Nothing (requestCon <$> code.requests) serializableTypeDerivClauses where - messageCon :: RpcFunction -> Q Con - messageCon fun = normalC (requestFunctionCtorName api fun) (messageConVar <$> fun.arguments) - where - messageConVar :: RpcArgument -> Q BangType - messageConVar (RpcArgument _name ty) = defaultBangType ty + requestCon :: Request -> Q Con + requestCon req = normalC (requestConName api req) (defaultBangType . (.ty) <$> req.fields) responseDec :: Q Dec - responseDec = dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> filter hasResult functions) serializableTypeDerivClauses + responseDec = do + dataD (pure []) (responseTypeName api) [] Nothing (responseCon <$> catMaybes ((.mResponse) <$> code.requests)) serializableTypeDerivClauses where - responseCon :: RpcFunction -> Q Con - responseCon fun = normalC (responseFunctionCtorName api fun) [defaultBangType (resultTupleType fun)] - resultTupleType :: RpcFunction -> Q Type - resultTupleType fun = buildTupleType (sequence ((.ty) <$> fun.results)) + responseCon :: Response -> Q Con + responseCon resp = normalC (responseConName api resp) (defaultBangType . (.ty) <$> resp.fields) serializableTypeDerivClauses :: [Q DerivClause] serializableTypeDerivClauses = [ derivClause Nothing [[t|Eq|], [t|Show|], [t|Generic|], [t|Binary|]] ] -makeClient :: RpcApi -> Q [Dec] -makeClient api@RpcApi{functions} = do - mconcat <$> mapM makeClientFunction functions +makeClient :: Code -> Q [Dec] +makeClient code = sequence code.stubDecs + +makeServer :: RpcApi -> Code -> Q [Dec] +makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstanceDec] where - makeClientFunction :: RpcFunction -> Q [Dec] - makeClientFunction fun = do + protocolImplDec :: Q Dec + protocolImplDec = do + dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) code.serverImplFields] [] + functionImplType :: RpcFunction -> Q Type + functionImplType fun = do + argumentTypes <- functionArgumentTypes fun + streamTypes <- serverStreamTypes + buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|] + where + serverStreamTypes :: Q [Type] + serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams + + logicInstanceDec :: Q Dec + logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [ + tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)), + requestHandler + ] + requestHandler :: Q Dec + requestHandler = do + requestHandlerPrimeName <- newName "handleRequest" + implRecordName <- newName "implementation" + channelName <- newName "onChannel" + funD 'handleRequest [clause [varP implRecordName, varP channelName] (normalB (varE requestHandlerPrimeName)) [funD requestHandlerPrimeName (requestHandlerClauses implRecordName channelName)]] + where + requestHandlerClauses :: Name -> Name -> [Q Clause] + requestHandlerClauses implRecordName channelName = (mconcat $ (requestClauses implRecordName channelName) <$> code.requests) + requestClauses :: Name -> Name -> Request -> [Q Clause] + requestClauses implRecordName channelName req = [mainClause, invalidChannelCountClause] + where + mainClause :: Q Clause + mainClause = do + channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (req.numPipelinedChannels - 1)] + + fieldNames <- sequence $ newName . (.name) <$> req.fields + let requestConP = conP (requestConName api req) (varP <$> fieldNames) + ctx = RequestHandlerContext { + implRecordE = varE implRecordName, + argumentEs = (varE <$> fieldNames), + channelEs = (varE <$> channelNames) + } + + clause + [requestConP, listP (varP <$> channelNames)] + (normalB (packResponse req.mResponse (req.handlerE ctx))) + [] + + invalidChannelCountClause :: Q Clause + invalidChannelCountClause = do + channelsName <- newName "newChannels" + let requestConP = conP (requestConName api req) (replicate (length req.fields) wildP) + clause + [requestConP, varP channelsName] + (normalB [|$(varE 'reportInvalidChannelCount) $(litE (integerL (toInteger req.numPipelinedChannels))) $(varE channelsName) $(varE channelName)|]) + [] + + packResponse :: Maybe Response -> Q Exp -> Q Exp + packResponse Nothing handlerE = [|Nothing <$ $(handlerE)|] + packResponse (Just response) handlerE = [|Just . $(conE (responseConName api response)) <$> $handlerE|] + + +-- * Pluggable codegen interface + +data Code = Code { + stubDecs :: [Q Dec], + serverImplFields :: [Q VarBangType], + requests :: [Request] +} +instance Semigroup Code where + x <> y = Code { + stubDecs = x.stubDecs <> y.stubDecs, + serverImplFields = x.serverImplFields <> y.serverImplFields, + requests = x.requests <> y.requests + } +instance Monoid Code where + mempty = Code { + stubDecs = [], + serverImplFields = [], + requests = [] + } + +data Request = Request { + name :: String, + fields :: [Field], + numPipelinedChannels :: Int, + mResponse :: Maybe Response, + handlerE :: RequestHandlerContext -> Q Exp +} + +data Response = Response { + name :: String, + fields :: [Field] + --numCreatedChannels :: Int +} + +data Field = Field { + name :: String, + ty :: Q Type +} +toField :: (HasField "name" a String, HasField "ty" a (Q Type)) => a -> Field +toField x = Field { name = x.name, ty = x.ty } + +data RequestHandlerContext = RequestHandlerContext { + implRecordE :: Q Exp, + argumentEs :: [Q Exp], + channelEs :: [Q Exp] +} + + +-- * Rpc function code generator + +generateFunction :: RpcApi -> RpcFunction -> Q Code +generateFunction api fun = do + stubDecs <- clientFunctionStub + pure Code { + stubDecs, + serverImplFields = + if isNothing fun.fixedHandler + then [ varDefaultBangType implFieldName implSig ] + else [], + requests = [request] + } + where + request :: Request + request = Request { + name = fun.name, + fields = toField <$> fun.arguments, + numPipelinedChannels = length fun.streams, + mResponse = if hasResult fun then Just response else Nothing, + handlerE = serverRequestHandlerE + } + response :: Response + response = Response { + name = fun.name, + -- TODO unpack? + fields = [ Field { name = "packedResponse", ty = buildTupleType (sequence ((.ty) <$> fun.results)) } ] + --numCreatedChannels = undefined + } + implFieldName :: Name + implFieldName = functionImplFieldName api fun + implSig :: Q Type + implSig = do + argumentTypes <- functionArgumentTypes fun + streamTypes <- serverStreamTypes + buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|] + where + serverStreamTypes :: Q [Type] + serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams + + serverRequestHandlerE :: RequestHandlerContext -> Q Exp + serverRequestHandlerE ctx = applyChannels ctx.channelEs (applyArgs (implFieldE ctx.implRecordE)) + where + implFieldE :: Q Exp -> Q Exp + implFieldE implRecordE = case fun.fixedHandler of + Nothing -> [|$(varE implFieldName) $implRecordE|] + Just handler -> [|$(handler) :: $implSig|] + applyArgs :: Q Exp -> Q Exp + applyArgs implE = foldl appE implE ctx.argumentEs + applyChannels :: [Q Exp] -> Q Exp -> Q Exp + applyChannels [] implE = implE + applyChannels (channel0E:channelEs) implE = varE 'join `appE` foldl + (\x y -> [|$x <*> $y|]) + ([|$implE <$> $(createStream channel0E)|]) + (createStream <$> channelEs) + where + createStream :: Q Exp -> Q Exp + createStream = (varE 'newStream `appE`) + + clientFunctionStub :: Q [Q Dec] + clientFunctionStub = do clientVarName <- newName "client" argNames <- sequence (newName . (.name) <$> fun.arguments) channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams) @@ -138,13 +305,13 @@ makeClient api@RpcApi{functions} = do where funName :: Name funName = mkName fun.name - makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Dec] + makeClientFunction' :: Name -> [Name] -> [Name] -> [Name] -> Q [Q Dec] makeClientFunction' clientVarName argNames channelNames streamNames = do funArgTypes <- functionArgumentTypes fun clientType <- [t|Client $(protocolType api)|] resultType <- optionalResultType streamTypes <- clientStreamTypes - sequence [ + pure [ sigD funName (buildFunctionType (pure ([clientType] <> funArgTypes)) [t|IO $(buildTupleType (pure (resultType <> streamTypes)))|]), funD funName [clause ([varP clientVarName] <> varPats) body []] ] @@ -161,10 +328,12 @@ makeClient api@RpcApi{functions} = do varPats = varP <$> argNames body :: Q Body body - | hasResult fun = normalB $ doE $ + | hasResult fun = do + responseName <- newName "response" + normalB $ doE $ [ - bindS [p|(response, resources)|] (requestE requestDataE), - bindS [p|result|] (checkResult [|response|]) + bindS [p|($(varP responseName), resources)|] (requestE requestDataE), + bindS [p|result|] (checkResult (varE responseName)) ] <> createStreams [|resources.createdChannels|] <> [noBindS [|pure $(buildTuple (liftA2 (:) [|result|] streamsE))|]] @@ -173,7 +342,7 @@ makeClient api@RpcApi{functions} = do createStreams [|resources.createdChannels|] <> [noBindS [|pure $(buildTuple streamsE)|]] requestDataE :: Q Exp - requestDataE = applyVars (conE (requestFunctionCtorName api fun)) + requestDataE = applyVars (conE (requestFunctionConName api fun)) createStreams :: Q Exp -> [Q Stmt] createStreams channelsE = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels] where @@ -223,111 +392,6 @@ makeClient api@RpcApi{functions} = do typedRequest :: Q Exp typedRequest = appTypeE (varE 'clientRequestBlocking) (protocolType api) - -makeServer :: RpcApi -> Q [Dec] -makeServer api@RpcApi{functions} = sequence [handlerRecordDec, logicInstanceDec] - where - handlerRecordDec :: Q Dec - handlerRecordDec = dataD (pure []) (implTypeName api) [] Nothing [recC (implTypeName api) (handlerRecordField <$> functionsWithoutBuiltinHandler)] [] - functionsWithoutBuiltinHandler :: [RpcFunction] - functionsWithoutBuiltinHandler = filter (isNothing . fixedHandler) functions - handlerRecordField :: RpcFunction -> Q VarBangType - handlerRecordField fun = varDefaultBangType (implFieldName api fun) (handlerFunctionType fun) - handlerFunctionType :: RpcFunction -> Q Type - handlerFunctionType fun = do - argumentTypes <- functionArgumentTypes fun - streamTypes <- serverStreamTypes - buildFunctionType (pure (argumentTypes <> streamTypes)) [t|IO $(buildTupleType (functionResultTypes fun))|] - where - serverStreamTypes :: Q [Type] - serverStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyDown) $(stream.tyUp)|]) <$> fun.streams - - logicInstanceDec :: Q Dec - logicInstanceDec = instanceD (cxt []) [t|HasProtocolImpl $(protocolType api)|] [ - tySynInstD (tySynEqn Nothing [t|ProtocolImpl $(protocolType api)|] (implType api)), - messageHandler - ] - messageHandler :: Q Dec - messageHandler = do - handleMessagePrimeName <- newName "handleMessage" - implName <- newName "impl" - channelsName <- newName "channels" - funD 'handleMessage [clause [varP implName, varP channelsName] (normalB (varE handleMessagePrimeName)) [handleMessagePrimeDec handleMessagePrimeName (varE implName) (varE channelsName)]] - where - handleMessagePrimeDec :: Name -> Q Exp -> Q Exp -> Q Dec - handleMessagePrimeDec handleMessagePrimeName implE channelsE = funD handleMessagePrimeName (handlerFunctionClause <$> functions) - where - handlerFunctionClause :: RpcFunction -> Q Clause - handlerFunctionClause fun = do - argNames <- sequence (newName . (.name) <$> fun.arguments) - channelNames <- sequence (newName . (<> "Channel") . (.name) <$> fun.streams) - streamNames <- sequence (newName . (.name) <$> fun.streams) - serverLogicHandlerFunctionClause' argNames channelNames streamNames - where - serverLogicHandlerFunctionClause' :: [Name] -> [Name] -> [Name] -> Q Clause - serverLogicHandlerFunctionClause' argNames channelNames streamNames = clause [conP (requestFunctionCtorName api fun) varPats] body [] - where - varPats :: [Q Pat] - varPats = varP <$> argNames - body :: Q Body - body = normalB $ doE $ createStreams <> [callImplementation] - createStreams :: [Q Stmt] - createStreams = if length fun.streams > 0 then [assignChannels] <> go channelNames streamNames else [verifyNoChannels] - where - verifyNoChannels :: Q Stmt - verifyNoChannels = noBindS [|unless (null $(channelsE)) (fail "Received invalid channel count")|] -- TODO channelReportProtocolError - assignChannels :: Q Stmt - assignChannels = - bindS - (tupP (varP <$> channelNames)) - $ caseE channelsE [ - match (listP (varP <$> channelNames)) (normalB [|pure $(tupE (varE <$> channelNames))|]) [], - match [p|_|] (normalB [|fail "Received invalid channel count"|]) [] -- TODO channelReportProtocolError - ] - go :: [Name] -> [Name] -> [Q Stmt] - go [] [] = [] - go (cn:cns) (sn:sns) = createStream cn sn : go cns sns - go _ _ = fail "Logic error: lists have different lengths" - createStream :: Name -> Name -> Q Stmt - createStream channelName streamName = bindS (varP streamName) [|$(varE 'newStream) $(varE channelName)|] - callImplementation :: Q Stmt - callImplementation = noBindS callImplementationE - callImplementationE :: Q Exp - callImplementationE - | hasResult fun = [|Just <$> $(packResponse (applyStreams (applyArguments implExp)))|] - | otherwise = [|Nothing <$ $(applyStreams (applyArguments implExp))|] - packResponse :: Q Exp -> Q Exp - packResponse = fmapE (conE (responseFunctionCtorName api fun)) - applyArguments :: Q Exp -> Q Exp - applyArguments = go argNames - where - go :: [Name] -> Q Exp -> Q Exp - go [] ex = ex - go (n:ns) ex = go ns (appE ex (varE n)) - applyStreams :: Q Exp -> Q Exp - applyStreams = go streamNames - where - go :: [Name] -> Q Exp -> Q Exp - go [] ex = ex - go (sn:sns) ex = go sns (appE ex (varE sn)) - implExp :: Q Exp - implExp = implExp' fun.fixedHandler - where - implExp' :: Maybe (Q Exp) -> Q Exp - implExp' Nothing = varE (implFieldName api fun) `appE` implE - implExp' (Just handler) = [| - let - impl :: $(implSig) - impl = $(handler) - in impl - |] - implSig :: Q Type - implSig = handlerFunctionType fun - --- * Internal - --- ** Protocol generator helpers - functionArgumentTypes :: RpcFunction -> Q [Type] functionArgumentTypes fun = sequence $ (.ty) <$> fun.arguments functionResultTypes :: RpcFunction -> Q [Type] @@ -337,7 +401,7 @@ hasResult :: RpcFunction -> Bool hasResult fun = not (null fun.results) --- *** Name helper functions +-- ** Name helper functions protocolTypeName :: RpcApi -> Name protocolTypeName RpcApi{name} = mkName (name <> "Protocol") @@ -351,8 +415,11 @@ requestTypeIdentifier RpcApi{name} = name <> "ProtocolRequest" requestTypeName :: RpcApi -> Name requestTypeName = mkName . requestTypeIdentifier -requestFunctionCtorName :: RpcApi -> RpcFunction -> Name -requestFunctionCtorName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name) +requestFunctionConName :: RpcApi -> RpcFunction -> Name +requestFunctionConName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name) + +requestConName :: RpcApi -> Request -> Name +requestConName api req = mkName (requestTypeIdentifier api <> "_" <> req.name) responseTypeIdentifier :: RpcApi -> String responseTypeIdentifier RpcApi{name} = name <> "ProtocolResponse" @@ -363,16 +430,19 @@ responseTypeName = mkName . responseTypeIdentifier responseFunctionCtorName :: RpcApi -> RpcFunction -> Name responseFunctionCtorName api fun = mkName (responseTypeIdentifier api <> "_" <> fun.name) +responseConName :: RpcApi -> Response -> Name +responseConName api resp = mkName (responseTypeIdentifier api <> "_" <> resp.name) + implTypeName :: RpcApi -> Name implTypeName RpcApi{name} = mkName $ name <> "ProtocolImpl" implType :: RpcApi -> Q Type implType = conT . implTypeName -implFieldName :: RpcApi -> RpcFunction -> Name -implFieldName _api fun = mkName (fun.name <> "Impl") +functionImplFieldName :: RpcApi -> RpcFunction -> Name +functionImplFieldName _api fun = mkName (fun.name <> "Impl") --- ** Template Haskell helper functions +-- * Template Haskell helper functions funT :: Q Type -> Q Type -> Q Type funT x = appT (appT arrowT x) @@ -412,3 +482,11 @@ varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpack fmapE :: Q Exp -> Q Exp -> Q Exp fmapE f e = [|$(f) <$> $(e)|] + +-- * Error reporting + +reportInvalidChannelCount :: Int -> [Channel] -> Channel -> IO a +reportInvalidChannelCount expectedCount newChannels onChannel = channelReportProtocolError onChannel msg + where + msg = mconcat parts + parts = ["Received ", show (length newChannels), " new channels, but expected ", show expectedCount]