diff --git a/flake.lock b/flake.lock index ab8fa280910e40e070c2769ea15db57799da55f1..e8954ad0db1471f4f3c8dd1f22bbd94c05fc5163 100644 --- a/flake.lock +++ b/flake.lock @@ -21,11 +21,11 @@ }, "locked": { "host": "git.c3pb.de", - "lastModified": 1626923341, - "narHash": "sha256-CWvh6F6d1kEN6IpMvDBxSBNl4oJP2FhRGU5uGLwZSBw=", + "lastModified": 1628533241, + "narHash": "sha256-nvAqgEzmdYhvwTb0y6Vico4EvyOT1ehgzZU7/LrnW2g=", "owner": "jens", "repo": "quasar", - "rev": "458784d70f664f3af9b98655505ca93e72610376", + "rev": "894908377f0ee5cf626a6bdd8c4fdd29411b8e80", "type": "gitlab" }, "original": { diff --git a/quasar-network.cabal b/quasar-network.cabal index f50a4e5d91488327fbd62d9313872377f3fc726a..a4b02b4260fe423c82dd3b93a8299d656ef81579 100644 --- a/quasar-network.cabal +++ b/quasar-network.cabal @@ -85,6 +85,7 @@ library Quasar.Network.Connection Quasar.Network.Multiplexer Quasar.Network.Runtime + Quasar.Network.Runtime.Observable Quasar.Network.SocketLocation Quasar.Network.TH hs-source-dirs: diff --git a/src/Quasar/Network.hs b/src/Quasar/Network.hs index fa25d1eea7f7f63dbae6558509092a76fe5c67d1..47fa25f06d20154cc4194f6329b17fc5f63a02f1 100644 --- a/src/Quasar/Network.hs +++ b/src/Quasar/Network.hs @@ -11,6 +11,7 @@ module Quasar.Network ( addResult, addStream, setFixedHandler, + rpcObservable, -- * Runtime diff --git a/src/Quasar/Network/Runtime.hs b/src/Quasar/Network/Runtime.hs index 567b878c33c2a0c9b613211dab07e81eeedcfd5d..64323920f02bcfc70aecdb56be36b933718038cd 100644 --- a/src/Quasar/Network/Runtime.hs +++ b/src/Quasar/Network/Runtime.hs @@ -47,7 +47,7 @@ import Data.Binary (Binary, encode, decodeOrFail) import qualified Data.ByteString.Lazy as BSL import qualified Data.HashMap.Strict as HM import qualified Network.Socket as Socket -import Quasar.Core +import Quasar.Awaitable import Quasar.Network.Connection import Quasar.Network.Multiplexer import Quasar.Prelude @@ -82,13 +82,13 @@ emptyClientState = ClientState { clientSend :: forall p m. (MonadIO m, RpcProtocol p) => Client p -> MessageConfiguration -> ProtocolRequest p -> m SentMessageResources clientSend client config req = liftIO $ channelSend_ client.channel config (encode req) -clientRequest :: forall p m a. (MonadIO m, RpcProtocol p) => Client p -> (ProtocolResponse p -> Maybe a) -> MessageConfiguration -> ProtocolRequest p -> m (Async a, SentMessageResources) +clientRequest :: forall p m a. (MonadIO m, RpcProtocol p) => Client p -> (ProtocolResponse p -> Maybe a) -> MessageConfiguration -> ProtocolRequest p -> m (Awaitable a, SentMessageResources) clientRequest client checkResponse config req = do resultAsync <- newAsyncVar sentMessageResources <- liftIO $ channelSend client.channel config (encode req) $ \msgId -> modifyMVar_ client.stateMVar $ \state -> pure state{callbacks = HM.insert msgId (requestCompletedCallback resultAsync msgId) state.callbacks} - pure (toAsync resultAsync, sentMessageResources) + pure (toAwaitable resultAsync, sentMessageResources) where requestCompletedCallback :: AsyncVar a -> MessageId -> ProtocolResponse p -> IO () requestCompletedCallback resultAsync msgId response = do @@ -97,7 +97,7 @@ clientRequest client checkResponse config req = do case checkResponse response of Nothing -> clientReportProtocolError client "Invalid response" - Just result -> putAsyncVar resultAsync result + Just result -> putAsyncVar_ resultAsync result clientHandleChannelMessage :: forall p. (RpcProtocol p) => Client p -> ReceivedMessageResources -> BSL.ByteString -> IO () clientHandleChannelMessage client resources msg = case decodeOrFail msg of @@ -153,14 +153,14 @@ streamClose (Stream channel) = liftIO $ channelClose channel -- ** Running client and server -withClientTCP :: RpcProtocol p => Socket.HostName -> Socket.ServiceName -> (Client p -> AsyncIO a) -> IO a +withClientTCP :: RpcProtocol p => Socket.HostName -> Socket.ServiceName -> (Client p -> IO a) -> IO a withClientTCP host port = withClientBracket (newClientTCP host port) newClientTCP :: forall p. RpcProtocol p => Socket.HostName -> Socket.ServiceName -> IO (Client p) newClientTCP host port = newClient =<< connectTCP host port -withClientUnix :: RpcProtocol p => FilePath -> (Client p -> AsyncIO a) -> IO a +withClientUnix :: RpcProtocol p => FilePath -> (Client p -> IO a) -> IO a withClientUnix socketPath = withClientBracket (newClientUnix socketPath) newClientUnix :: RpcProtocol p => FilePath -> IO (Client p) @@ -170,14 +170,14 @@ newClientUnix socketPath = bracketOnError (Socket.socket Socket.AF_UNIX Socket.S newClient sock -withClient :: forall p a b. (IsConnection a, RpcProtocol p) => a -> (Client p -> AsyncIO b) -> IO b +withClient :: forall p a b. (IsConnection a, RpcProtocol p) => a -> (Client p -> IO b) -> IO b withClient connection = withClientBracket (newClient connection) newClient :: forall p a. (IsConnection a, RpcProtocol p) => a -> IO (Client p) newClient connection = newChannelClient =<< newMultiplexer MultiplexerSideA (toSocketConnection connection) -withClientBracket :: forall p a. (RpcProtocol p) => IO (Client p) -> (Client p -> AsyncIO a) -> IO a -withClientBracket createClient action = bracket createClient clientClose $ \client -> runAsyncIO (action client) +withClientBracket :: forall p a. (RpcProtocol p) => IO (Client p) -> (Client p -> IO a) -> IO a +withClientBracket createClient = bracket createClient clientClose newChannelClient :: RpcProtocol p => Channel -> IO (Client p) @@ -293,8 +293,8 @@ runServerHandler protocolImpl = runMultiplexer MultiplexerSideB registerChannelS registerChannelServerHandler channel = channelSetHandler channel (serverHandleChannelMessage @p protocolImpl channel) -withLocalClient :: forall p a. (RpcProtocol p, HasProtocolImpl p) => Server p -> ((Client p) -> AsyncIO a) -> IO a -withLocalClient server action = bracket (newLocalClient server) clientClose $ \client -> runAsyncIO (action client) +withLocalClient :: forall p m a. (RpcProtocol p, HasProtocolImpl p) => Server p -> (Client p -> IO a) -> IO a +withLocalClient server = bracket (newLocalClient server) clientClose newLocalClient :: forall p. (RpcProtocol p, HasProtocolImpl p) => Server p -> IO (Client p) newLocalClient server = do @@ -306,5 +306,5 @@ newLocalClient server = do -- ** Test implementation -withStandaloneClient :: forall p a. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> (Client p -> AsyncIO a) -> IO a +withStandaloneClient :: forall p a. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> (Client p -> IO a) -> IO a withStandaloneClient impl runClientHook = withServer impl [] $ \server -> withLocalClient server runClientHook diff --git a/src/Quasar/Network/Runtime/Observable.hs b/src/Quasar/Network/Runtime/Observable.hs new file mode 100644 index 0000000000000000000000000000000000000000..49ae3b74143217199556c1239647b2d6547ef0cb --- /dev/null +++ b/src/Quasar/Network/Runtime/Observable.hs @@ -0,0 +1,12 @@ +module Quasar.Network.Runtime.Observable () where + +import Quasar.Network.Runtime +import Quasar.Core +import Quasar.Observable +import Quasar.Prelude + +newNetworkObservable + :: ((ObservableMessage v -> IO ()) -> IO Disposable) + -> (forall m. HasResourceManager m => m (Task v)) + -> IO (Observable v) +newNetworkObservable observeFn retrieveFn = pure $ fnObservable observeFn retrieveFn diff --git a/src/Quasar/Network/TH.hs b/src/Quasar/Network/TH.hs index e25c7bef7c4cc78401adc6a5a97425cfe0e0f3f8..503d87af3394d93504f81efb2866a8a7f875cb67 100644 --- a/src/Quasar/Network/TH.hs +++ b/src/Quasar/Network/TH.hs @@ -10,6 +10,7 @@ module Quasar.Network.TH ( addResult, addStream, setFixedHandler, + rpcObservable, makeRpc, -- TODO: re-add functions that generate only client and server later RpcProtocol(ProtocolRequest, ProtocolResponse), @@ -19,18 +20,20 @@ module Quasar.Network.TH ( import Control.Monad.State (State, execState) import qualified Control.Monad.State as State import Data.Binary (Binary) -import Data.Maybe (isNothing) +import Data.Maybe (isJust, isNothing) import GHC.Records.Compat (HasField) import Language.Haskell.TH hiding (interruptible) import Language.Haskell.TH.Syntax -import Quasar.Core +import Quasar.Awaitable import Quasar.Network.Multiplexer import Quasar.Network.Runtime +import Quasar.Observable import Quasar.Prelude data RpcApi = RpcApi { name :: String, - functions :: [ RpcFunction ] + functions :: [ RpcFunction ], + observables :: [ RpcObservable ] } data RpcFunction = RpcFunction { @@ -57,20 +60,36 @@ data RpcStream = RpcStream { tyDown :: Q Type } -rpcApi :: String -> [RpcFunction] -> RpcApi -rpcApi apiName functions = RpcApi { +data RpcObservable = RpcObservable { + name :: String, + ty :: Q Type +} + +rpcApi :: String -> State RpcApi () -> RpcApi +rpcApi apiName setup = execState setup RpcApi { name = apiName, - functions = functions + functions = [], + observables = [] } -rpcFunction :: String -> State RpcFunction () -> RpcFunction -rpcFunction methodName setup = execState setup RpcFunction { - name = methodName, - arguments = [], - results = [], - streams = [], - fixedHandler = Nothing - } +rpcFunction :: String -> State RpcFunction () -> State RpcApi () +rpcFunction methodName setup = State.modify (\api -> api{functions = api.functions <> [fun]}) + where + fun = execState setup RpcFunction { + name = methodName, + arguments = [], + results = [], + streams = [], + fixedHandler = Nothing + } + +rpcObservable :: String -> Q Type -> State RpcApi () +rpcObservable name ty = State.modify (\api -> api{observables = api.observables <> [observable]}) + where + observable = RpcObservable { + name, + ty + } addArgument :: String -> Q Type -> State RpcFunction () addArgument name t = State.modify (\fun -> fun{arguments = fun.arguments <> [RpcArgument name t]}) @@ -90,8 +109,8 @@ setFixedHandler handler = State.modify (\fun -> fun{fixedHandler = Just handler} -- | Generates rpc protocol types, rpc client and rpc server makeRpc :: RpcApi -> Q [Dec] makeRpc api = do - code <- mconcat <$> sequence (generateFunction api <$> api.functions) - mconcat <$> sequence [makeProtocol api code, makeClient code, makeServer api code] + code <- mconcat <$> sequence ((generateFunction api <$> api.functions) <> (generateObservable api <$> api.observables)) + mconcat <$> sequence [makeProtocol api code, makeClient api code, makeServer api code] makeProtocol :: RpcApi -> Code -> Q [Dec] makeProtocol api code = sequence [protocolDec, protocolInstanceDec, requestDec, responseDec] @@ -124,8 +143,106 @@ makeProtocol api code = sequence [protocolDec, protocolInstanceDec, requestDec, derivClause (Just AnyclassStrategy) [[t|Binary|]] ] -makeClient :: Code -> Q [Dec] -makeClient code = sequence code.clientStubDecs +makeClient :: RpcApi -> Code -> Q [Dec] +makeClient api code = do + requestStubDecs <- mconcat <$> sequence (clientRequestStub api <$> code.requests) + sequence $ code.clientStubDecs <> requestStubDecs + +clientRequestStub :: RpcApi -> Request -> Q [Q Dec] +clientRequestStub api req = do + clientStubPrimeName <- newName req.name + clientVarName <- newName "client" + argNames <- sequence (newName . (.name) <$> req.fields) + clientRequestStub' clientStubPrimeName clientVarName argNames + where + stubName :: Name + stubName = clientRequestStubName api req + makeStubSig :: Q [Type] -> Q Type + makeStubSig arguments = + [t|forall m. MonadIO m => $(buildFunctionType arguments [t|m $(buildTupleType (liftA2 (<>) optionalResultType resourceTypes))|])|] + resourceTypes :: Q [Type] + resourceTypes = sequence $ resourceType <$> req.createdResources + optionalResultType :: Q [Type] + optionalResultType = case req.mResponse of + Nothing -> pure [] + Just resp -> sequence [[t|Awaitable $(buildTupleType (sequence ((.ty) <$> resp.fields)))|]] + resourceType :: RequestCreateResource -> Q Type + resourceType RequestCreateChannel = [t|Channel|] + resourceType (RequestCreateStream up down) = [t|Stream $up $down|] + + clientRequestStub' :: Name -> Name -> [Name] -> Q [Q Dec] + clientRequestStub' clientStubPrimeName clientVarName argNames = do + pure [ + clientRequestStubSigDec api req, + funD stubName [clause ([varP clientVarName] <> varPats) body clientStubPrimeDecs] + ] + where + clientE :: Q Exp + clientE = varE clientVarName + varPats :: [Q Pat] + varPats = varP <$> argNames + body :: Q Body + body = case req.mResponse of + Just resp -> normalB [|$(requestE resp requestDataE) >>= \(result, resources) -> $(varE clientStubPrimeName) result resources.createdChannels|] + Nothing -> normalB [|$(sendE requestDataE) >>= \resources -> $(varE clientStubPrimeName) resources.createdChannels|] + clientStubPrimeDecs :: [Q Dec] + clientStubPrimeDecs = [ + sigD clientStubPrimeName (makeStubSig (liftA2 (<>) optionalResultType (sequence [[t|[Channel]|]]))), + funD clientStubPrimeName (clientStubPrimeClauses req) + ] + clientStubPrimeClauses :: Request -> [Q Clause] + clientStubPrimeClauses req = [mainClause, invalidChannelCountClause] + where + mainClause :: Q Clause + mainClause = do + resultAsyncName <- newName "result" + + channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (numPipelinedChannels req - 1)] + + clause + (whenHasResult (varP resultAsyncName) <> [listP (varP <$> channelNames)]) + (normalB (buildTupleM (sequence (whenHasResult [|pure $(varE resultAsyncName)|] <> ((\x -> [|newStream $(varE x)|]) <$> channelNames))))) + [] + + invalidChannelCountClause :: Q Clause + invalidChannelCountClause = do + channelsName <- newName "newChannels" + clause + (whenHasResult wildP <> [varP channelsName]) + (normalB [|$(varE 'multiplexerInvalidChannelCount) $(litE (integerL (toInteger (numPipelinedChannels req)))) $(varE channelsName)|]) + [] + hasResponse :: Bool + hasResponse = isJust req.mResponse + whenHasResult :: a -> [a] + whenHasResult x = [x | hasResponse] + requestDataE :: Q Exp + requestDataE = applyVars (conE (requestConName api req)) + messageConfigurationE :: Q Exp + messageConfigurationE = [|defaultMessageConfiguration{createChannels = $(litE $ integerL $ toInteger $ numPipelinedChannels req)}|] + sendE :: Q Exp -> Q Exp + sendE msgExp = [|$typedSend $clientE $messageConfigurationE $msgExp|] + requestE :: Response -> Q Exp -> Q Exp + requestE resp msgExp = [|$typedRequest $clientE $checkResult $messageConfigurationE $msgExp|] + where + checkResult :: Q Exp + checkResult = lamCaseE [valid, invalid] + valid :: Q Match + valid = do + result <- newName "result" + match (conP (responseConName api resp) [varP result]) (normalB [|pure $(varE result)|]) [] + invalid :: Q Match + invalid = match wildP (normalB [|Nothing|]) [] + applyVars :: Q Exp -> Q Exp + applyVars = go argNames + where + go :: [Name] -> Q Exp -> Q Exp + go [] ex = ex + go (n:ns) ex = go ns (appE ex (varE n)) + -- check if the response to a request matches the expected response constructor + typedSend :: Q Exp + typedSend = appTypeE (varE 'clientSend) (protocolType api) + typedRequest :: Q Exp + typedRequest = appTypeE (varE 'clientRequest) (protocolType api) makeServer :: RpcApi -> Code -> Q [Dec] makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstanceDec] @@ -153,7 +270,7 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance where mainClause :: Q Clause mainClause = do - channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (req.numPipelinedChannels - 1)] + channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (numPipelinedChannels req - 1)] fieldNames <- sequence $ newName . (.name) <$> req.fields let requestConP = conP (requestConName api req) (varP <$> fieldNames) @@ -174,7 +291,7 @@ makeServer api@RpcApi{functions} code = sequence [protocolImplDec, logicInstance let requestConP = conP (requestConName api req) (replicate (length req.fields) wildP) clause [requestConP, varP channelsName] - (normalB [|$(varE 'reportInvalidChannelCount) $(litE (integerL (toInteger req.numPipelinedChannels))) $(varE channelsName) $(varE channelName)|]) + (normalB [|$(varE 'reportInvalidChannelCount) $(litE (integerL (toInteger (numPipelinedChannels req)))) $(varE channelsName) $(varE channelName)|]) [] packResponse :: Maybe Response -> Q Exp -> Q Exp @@ -205,11 +322,13 @@ instance Monoid Code where data Request = Request { name :: String, fields :: [Field], - numPipelinedChannels :: Int, + createdResources :: [RequestCreateResource], mResponse :: Maybe Response, handlerE :: RequestHandlerContext -> Q Exp } +data RequestCreateResource = RequestCreateChannel | RequestCreateStream (Q Type) (Q Type) + data Response = Response { name :: String, fields :: [Field] @@ -232,6 +351,44 @@ data RequestHandlerContext = RequestHandlerContext { -- * Rpc function code generator +generateObservable :: RpcApi -> RpcObservable -> Q Code +generateObservable api observable = pure Code { + clientStubDecs = [], + requests = [observeRequest, retrieveRequest], + serverImplFields = [varDefaultBangType serverImplFieldName serverImplFieldSig] +} + where + observeRequest :: Request + observeRequest = Request { + name = observable.name <> "_observe", + fields = [], + createdResources = [RequestCreateStream [t|Void|] observable.ty], + mResponse = Nothing, + handlerE = \ctx -> [|undefined|] + } + retrieveRequest :: Request + retrieveRequest = Request { + name = observable.name <> "_retrieve", + fields = [], + createdResources = [], + mResponse = Just retrieveResponse, + -- TODO use awaitable for result instead of blocking the network thread + handlerE = \ctx -> [|withDefaultResourceManager (awaitResult (retrieve $(observableE ctx)))|] + } + retrieveResponse :: Response + retrieveResponse = Response { + name = observable.name <> "_retrieve", + fields = [Field "result" observable.ty] + } + serverImplFieldName :: Name + serverImplFieldName = mkName observable.name + serverImplFieldSig :: Q Type + serverImplFieldSig = [t|Observable $(observable.ty)|] + observableE :: RequestHandlerContext -> Q Exp + observableE ctx = [|$(varE serverImplFieldName) $(ctx.implRecordE)|] + --observeE :: Q Exp + --observeE = r + generateFunction :: RpcApi -> RpcFunction -> Q Code generateFunction api fun = do clientStubDecs <- clientFunctionStub @@ -248,7 +405,7 @@ generateFunction api fun = do request = Request { name = fun.name, fields = toField <$> fun.arguments, - numPipelinedChannels = length fun.streams, + createdResources = (\stream -> RequestCreateStream stream.tyUp stream.tyDown) <$> fun.streams, mResponse = if hasResult fun then Just response else Nothing, handlerE = serverRequestHandlerE } @@ -286,102 +443,28 @@ generateFunction api fun = do clientFunctionStub :: Q [Q Dec] clientFunctionStub = do - clientStubPrimeName <- newName fun.name - clientVarName <- newName "client" - argNames <- sequence (newName . (.name) <$> fun.arguments) - makeClientFunction' clientStubPrimeName clientVarName argNames + funArgTypes <- functionArgumentTypes fun + clientType <- [t|Client $(protocolType api)|] + pure [ + sigD funName (clientRequestStubSig api request), + funD funName [clause [] (normalB (clientRequestStubE api request)) []] + ] where funName :: Name funName = mkName fun.name - makeClientFunction' :: Name -> Name -> [Name] -> Q [Q Dec] - makeClientFunction' clientStubPrimeName clientVarName argNames = do - funArgTypes <- functionArgumentTypes fun - clientType <- [t|Client $(protocolType api)|] - pure [ - sigD funName (makeStubSig (pure (clientType : funArgTypes))), - funD funName [clause ([varP clientVarName] <> varPats) body clientStubPrimeDecs] - ] - where - makeStubSig :: Q [Type] -> Q Type - makeStubSig arguments = buildFunctionType arguments [t|AsyncIO $(buildTupleType (liftA2 (<>) optionalResultType clientStreamTypes))|] - optionalResultType :: Q [Type] - optionalResultType = sequence $ whenHasResult [t|Async $(buildTupleType (functionResultTypes fun))|] - clientStreamTypes :: Q [Type] - clientStreamTypes = sequence $ (\stream -> [t|Stream $(stream.tyUp) $(stream.tyDown)|]) <$> fun.streams - clientE :: Q Exp - clientE = varE clientVarName - varPats :: [Q Pat] - varPats = varP <$> argNames - body :: Q Body - body - | hasResult fun = normalB ([|$(requestE requestDataE) >>= \(result, resources) -> $(varE clientStubPrimeName) result resources.createdChannels|]) - | otherwise = normalB ([|$(sendE requestDataE) >>= \resources -> $(varE clientStubPrimeName) resources.createdChannels|]) - clientStubPrimeDecs :: [Q Dec] - clientStubPrimeDecs = [ - sigD clientStubPrimeName (makeStubSig (liftA2 (<>) optionalResultType (sequence [[t|[Channel]|]]))), - funD clientStubPrimeName (clientStubPrimeClauses request) - ] - clientStubPrimeClauses :: Request -> [Q Clause] - clientStubPrimeClauses req = [mainClause, invalidChannelCountClause] - where - mainClause :: Q Clause - mainClause = do - resultAsyncName <- newName "result" - - channelNames <- sequence $ newName . ("channel" <>) . show <$> [0 .. (req.numPipelinedChannels - 1)] - - clause - (whenHasResult (varP resultAsyncName) <> [listP (varP <$> channelNames)]) - (normalB (buildTupleM (sequence (whenHasResult [|pure $(varE resultAsyncName)|] <> ((\x -> [|newStream $(varE x)|]) <$> channelNames))))) - [] - - invalidChannelCountClause :: Q Clause - invalidChannelCountClause = do - channelsName <- newName "newChannels" - clause - (whenHasResult wildP <> [varP channelsName]) - (normalB [|$(varE 'multiplexerInvalidChannelCount) $(litE (integerL (toInteger req.numPipelinedChannels))) $(varE channelsName)|]) - [] - whenHasResult :: a -> [a] - whenHasResult x = if hasResult fun then [x] else [] - requestDataE :: Q Exp - requestDataE = applyVars (conE (requestFunctionConName api fun)) - messageConfigurationE :: Q Exp - messageConfigurationE = [|defaultMessageConfiguration{createChannels = $(litE $ integerL $ toInteger $ length fun.streams)}|] - sendE :: Q Exp -> Q Exp - sendE msgExp = [|$typedSend $clientE $messageConfigurationE $msgExp|] - requestE :: Q Exp -> Q Exp - requestE msgExp = [|$typedRequest $clientE $checkResult $messageConfigurationE $msgExp|] - applyVars :: Q Exp -> Q Exp - applyVars = go argNames - where - go :: [Name] -> Q Exp -> Q Exp - go [] ex = ex - go (n:ns) ex = go ns (appE ex (varE n)) - -- check if the response to a request matches the expected response constructor - checkResult :: Q Exp - checkResult = lamCaseE [valid, invalid] - where - valid :: Q Match - valid = do - result <- newName "result" - match (conP (responseFunctionCtorName api fun) [varP result]) (normalB [|pure $(varE result)|]) [] - invalid :: Q Match - invalid = match wildP (normalB [|Nothing|]) [] - - typedSend :: Q Exp - typedSend = appTypeE (varE 'clientSend) (protocolType api) - typedRequest :: Q Exp - typedRequest = appTypeE (varE 'clientRequest) (protocolType api) + functionArgumentTypes :: RpcFunction -> Q [Type] functionArgumentTypes fun = sequence $ (.ty) <$> fun.arguments + functionResultTypes :: RpcFunction -> Q [Type] functionResultTypes fun = sequence $ (.ty) <$> fun.results hasResult :: RpcFunction -> Bool hasResult fun = not (null fun.results) +numPipelinedChannels :: Request -> Int +numPipelinedChannels req = length req.createdResources -- ** Name helper functions @@ -397,9 +480,6 @@ requestTypeIdentifier RpcApi{name} = name <> "ProtocolRequest" requestTypeName :: RpcApi -> Name requestTypeName = mkName . requestTypeIdentifier -requestFunctionConName :: RpcApi -> RpcFunction -> Name -requestFunctionConName api fun = mkName (requestTypeIdentifier api <> "_" <> fun.name) - requestConName :: RpcApi -> Request -> Name requestConName api req = mkName (requestTypeIdentifier api <> "_" <> req.name) @@ -424,6 +504,34 @@ implType = conT . implTypeName functionImplFieldName :: RpcApi -> RpcFunction -> Name functionImplFieldName _api fun = mkName (fun.name <> "Impl") +clientRequestStubName :: RpcApi -> Request -> Name +clientRequestStubName api req = mkName ("_" <> api.name <> "_" <> req.name) + +clientRequestStubE :: RpcApi -> Request -> Q Exp +clientRequestStubE api req = (varE (clientRequestStubName api req)) + +clientRequestStubSig :: RpcApi -> Request -> Q Type +clientRequestStubSig api req = do + reqFieldTypes <- sequence $ (.ty) <$> req.fields + clientType <- [t|Client $(protocolType api)|] + makeStubSig (pure (clientType : reqFieldTypes)) + where + makeStubSig :: Q [Type] -> Q Type + makeStubSig arguments = + [t|forall m. MonadIO m => $(buildFunctionType arguments [t|m $(buildTupleType (liftA2 (<>) optionalResultType resourceTypes))|])|] + resourceTypes :: Q [Type] + resourceTypes = sequence $ resourceType <$> req.createdResources + optionalResultType :: Q [Type] + optionalResultType = case req.mResponse of + Nothing -> pure [] + Just resp -> sequence [[t|Awaitable $(buildTupleType (sequence ((.ty) <$> resp.fields)))|]] + resourceType :: RequestCreateResource -> Q Type + resourceType RequestCreateChannel = [t|Channel|] + resourceType (RequestCreateStream up down) = [t|Stream $up $down|] + +clientRequestStubSigDec :: RpcApi -> Request -> Q Dec +clientRequestStubSigDec api req = sigD (clientRequestStubName api req) (clientRequestStubSig api req) + -- * Template Haskell helper functions funT :: Q Type -> Q Type -> Q Type @@ -442,6 +550,8 @@ buildTupleType fields = buildTupleType' =<< fields go t (f:fs) = go (AppT t f) fs -- | [a, b, c] -> (a, b, c) +-- [a] -> a +-- [] -> () buildTuple :: Q [Exp] -> Q Exp buildTuple fields = buildTuple' =<< fields where @@ -451,6 +561,8 @@ buildTuple fields = buildTuple' =<< fields buildTuple' fs = pure $ TupE (Just <$> fs) -- | [m a, m b, m c] -> m (a, b, c) +-- [m a] -> m a +-- [] -> m () buildTupleM :: Q [Exp] -> Q Exp buildTupleM fields = buildTuple' =<< fields where diff --git a/test/Quasar/NetworkSpec.hs b/test/Quasar/NetworkSpec.hs index abbeaaced33b7490bba661d39ae2d479cfdaf1c9..014a9a343dc183f944f425bea810ce98037c818b 100644 --- a/test/Quasar/NetworkSpec.hs +++ b/test/Quasar/NetworkSpec.hs @@ -11,8 +11,9 @@ module Quasar.NetworkSpec where import Control.Concurrent.MVar -import Control.Monad.IO.Class (liftIO) +import Control.Monad.IO.Class (MonadIO, liftIO) import Prelude +import Quasar.Awaitable import Quasar.Core import Quasar.Network import Quasar.Network.Runtime (withStandaloneClient) @@ -25,38 +26,43 @@ shouldReturnAsync :: (HasCallStack, Show a, Eq a) => AsyncIO a -> a -> AsyncIO ( action `shouldReturnAsync` expected = action >>= liftIO . (`shouldBe` expected) -$(makeRpc $ rpcApi "Example" [ +$(makeRpc $ rpcApi "Example" $ do rpcFunction "fixedHandler42" $ do addArgument "arg" [t|Int|] addResult "result" [t|Bool|] - setFixedHandler [| pure . (== 42) |], + setFixedHandler [| pure . (== 42) |] + rpcFunction "fixedHandlerInc" $ do addArgument "arg" [t|Int|] addResult "result" [t|Int|] - setFixedHandler [| pure . (+ 1) |], + setFixedHandler [| pure . (+ 1) |] + rpcFunction "multiArgs" $ do addArgument "one" [t|Int|] addArgument "two" [t|Int|] addArgument "three" [t|Bool|] addResult "result" [t|Int|] - addResult "result2" [t|Bool|], + addResult "result2" [t|Bool|] + rpcFunction "noArgs" $ do - addResult "result" [t|Int|], + addResult "result" [t|Int|] + rpcFunction "noResponse" $ do - addArgument "arg" [t|Int|], + addArgument "arg" [t|Int|] + rpcFunction "noNothing" $ pure () - ] - ) + ) -$(makeRpc $ rpcApi "StreamExample" [ +$(makeRpc $ rpcApi "StreamExample" $ do rpcFunction "createMultiplyStream" $ do addStream "stream" [t|(Int, Int)|] [t|Int|] - , + rpcFunction "createStreams" $ do addStream "stream1" [t|Bool|] [t|Bool|] addStream "stream2" [t|Int|] [t|Int|] - ] - ) + + rpcObservable "intObservable" [t|Int|] + ) exampleProtocolImpl :: ExampleProtocolImpl exampleProtocolImpl = ExampleProtocolImpl { @@ -72,9 +78,9 @@ streamExampleProtocolImpl = StreamExampleProtocolImpl { createStreamsImpl } where - createMultiplyStreamImpl :: Stream Int (Int, Int) -> IO () + createMultiplyStreamImpl :: MonadIO m => Stream Int (Int, Int) -> m () createMultiplyStreamImpl stream = streamSetHandler stream $ \(x, y) -> streamSend stream (x * y) - createStreamsImpl :: Stream Bool Bool -> Stream Int Int -> IO () + createStreamsImpl :: MonadIO m => Stream Bool Bool -> Stream Int Int -> m () createStreamsImpl stream1 stream2 = do streamSetHandler stream1 $ streamSend stream1 streamSetHandler stream2 $ streamSend stream2 @@ -84,10 +90,10 @@ spec = parallel $ do describe "Example" $ do it "works" $ do withStandaloneClient @ExampleProtocol exampleProtocolImpl $ \client -> do - awaitResult (fixedHandler42 client 5) `shouldReturnAsync` False - awaitResult (fixedHandler42 client 42) `shouldReturnAsync` True - awaitResult (fixedHandlerInc client 41) `shouldReturnAsync` 42 - awaitResult (multiArgs client 10 3 False) `shouldReturnAsync` (13, True) + (awaitIO =<< fixedHandler42 client 5) `shouldReturn` False + (awaitIO =<< fixedHandler42 client 42) `shouldReturn` True + (awaitIO =<< fixedHandlerInc client 41) `shouldReturn` 42 + (awaitIO =<< multiArgs client 10 3 False) `shouldReturn` (13, True) noResponse client 1337 noNothing client