From ee61f057e2f723ee2464694e81ab344f72edd08c Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Tue, 18 May 2021 22:57:21 +0200 Subject: [PATCH] Multiplexer: Implement channel management --- qrpc.cabal | 1 + src/Network/Rpc.hs | 23 +- src/Network/Rpc/Connection.hs | 7 +- src/Network/Rpc/Multiplexer.hs | 314 +++++++++++++++++++++------- test/Network/Rpc/MultiplexerSpec.hs | 56 +++++ 5 files changed, 310 insertions(+), 91 deletions(-) create mode 100644 test/Network/Rpc/MultiplexerSpec.hs diff --git a/qrpc.cabal b/qrpc.cabal index 472a53c..a996e5e 100644 --- a/qrpc.cabal +++ b/qrpc.cabal @@ -94,5 +94,6 @@ test-suite qrpc-test main-is: Spec.hs other-modules: Network.RpcSpec + Network.Rpc.MultiplexerSpec hs-source-dirs: test diff --git a/src/Network/Rpc.hs b/src/Network/Rpc.hs index 30f6296..0190ee6 100644 --- a/src/Network/Rpc.hs +++ b/src/Network/Rpc.hs @@ -2,7 +2,7 @@ module Network.Rpc where import Control.Concurrent (forkFinally) import Control.Concurrent.Async (Async, async, link, withAsync) -import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError) +import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError, interruptible) import Control.Monad ((>=>), when, forever) import Control.Monad.State (State, execState) import qualified Control.Monad.State as State @@ -12,14 +12,13 @@ import qualified Data.ByteString.Lazy as BSL import Data.Hashable (Hashable) import qualified Data.HashMap.Strict as HM import Data.Maybe (isNothing) -import Language.Haskell.TH +import Language.Haskell.TH hiding (interruptible) import Language.Haskell.TH.Syntax import Network.Rpc.Multiplexer import Network.Rpc.Connection import qualified Network.Socket as Socket import Prelude import GHC.Generics -import GHC.IO (unsafeUnmask) import System.Posix.Files (getFileStatus, isSocket) @@ -261,11 +260,11 @@ emptyClientState = ClientState { } clientSend :: RpcProtocol p => Client p -> ProtocolRequest p -> IO () -clientSend client req = channelSend_ client.channel (encode req) [] +clientSend client req = channelSend_ client.channel [] (encode req) clientRequestBlocking :: forall p. RpcProtocol p => Client p -> ProtocolRequest p -> IO (ProtocolResponse p) clientRequestBlocking client req = do resultMVar <- newEmptyMVar - channelSend client.channel (encode req) [] $ \msgId -> + channelSend client.channel [] (encode req) $ \msgId -> modifyMVar_ client.stateMVar $ \state -> pure state{callbacks = HM.insert msgId (requestCompletedCallback resultMVar msgId) state.callbacks} -- Block on resultMVar until the request completes @@ -305,7 +304,7 @@ serverHandleChannelMessage protocolImpl channel msgId headers msg = case decodeO serverHandleChannelRequest :: ProtocolRequest p -> IO () serverHandleChannelRequest req = handleMessage @p protocolImpl req >>= maybe (pure ()) serverSendResponse serverSendResponse :: ProtocolResponse p -> IO () - serverSendResponse response = channelSend_ channel (encode wrappedResponse) [] + serverSendResponse response = channelSend_ channel [] (encode wrappedResponse) where wrappedResponse :: ProtocolResponseWrapper p wrappedResponse = (msgId, response) @@ -337,11 +336,7 @@ withClient :: forall p a b. (IsConnection a, RpcProtocol p) => a -> (Client p -> withClient x = bracket (newClient x) clientClose newClient :: forall p a. (IsConnection a, RpcProtocol p) => a -> IO (Client p) -newClient x = do - clientMVar <- newEmptyMVar - -- 'runMultiplexerProtcol' needs to be interruptible (so it can terminate when it is closed), so 'unsafeUnmask' is used to ensure that this function also works when used in 'bracket' - link =<< async (unsafeUnmask (runMultiplexerProtocol (newChannelClient >=> putMVar clientMVar) (toSocketConnection x))) - takeMVar clientMVar +newClient x = newChannelClient =<< newMultiplexer (toSocketConnection x) newChannelClient :: RpcProtocol p => Channel -> IO (Client p) @@ -400,7 +395,7 @@ listenOnBoundSocket protocolImpl sock = do Socket.gracefulClose conn 2000 runServerHandler :: forall p a. (RpcProtocol p, HasProtocolImpl p, IsConnection a) => ProtocolImpl p -> a -> IO () -runServerHandler protocolImpl = runMultiplexerProtocol (registerChannelServerHandler @p protocolImpl) . toSocketConnection +runServerHandler protocolImpl = runMultiplexer (registerChannelServerHandler @p protocolImpl) . toSocketConnection -- ** Test implementation @@ -483,10 +478,10 @@ buildTupleType fields = buildTupleType' =<< fields go t (f:fs) = go (AppT t f) fs buildFunctionType :: Q [Type] -> Q Type -> Q Type -buildFunctionType argTypes pureType = go =<< argTypes +buildFunctionType argTypes returnType = go =<< argTypes where go :: [Type] -> Q Type - go [] = pureType + go [] = returnType go (t:ts) = pure t `funT` go ts defaultBangType :: Q Type -> Q BangType diff --git a/src/Network/Rpc/Connection.hs b/src/Network/Rpc/Connection.hs index 95f4d0d..1b0efca 100644 --- a/src/Network/Rpc/Connection.hs +++ b/src/Network/Rpc/Connection.hs @@ -3,12 +3,11 @@ module Network.Rpc.Connection where import Control.Concurrent (threadDelay) import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync) import Control.Concurrent.MVar -import Control.Exception (Exception(..), SomeException, bracketOnError, finally, throwIO, bracketOnError, onException) +import Control.Exception (Exception(..), SomeException, bracketOnError, interruptible, finally, throwIO, bracketOnError, onException) import Control.Monad ((>=>), unless, forM_) import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BSL import Data.List (intercalate) -import GHC.IO (unsafeUnmask) import qualified Network.Socket as Socket import qualified Network.Socket.ByteString as Socket import qualified Network.Socket.ByteString.Lazy as SocketL @@ -37,6 +36,8 @@ newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)] instance Exception ConnectionFailed where displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts) +-- | Open a TCP connection to target host and port. Will start multiple connection attempts (i.e. retry quickly and then try other addresses) but only return the first successful connection. +-- Throws a 'ConnectionFailed' on failure, which contains the exceptions from all failed connection attempts. connectTCP :: Socket.HostName -> Socket.ServiceName -> IO Socket.Socket connectTCP host port = do -- 'getAddrInfo' either pures a non-empty list or throws an exception @@ -74,7 +75,7 @@ connectTCP host port = do -- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail sock <- - (withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar)) + (withAsync (interruptible raceConnections) (link >=> const (readMVar sockMVar)) `finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar)) `onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar) -- As soon as we have an open connection, stop spawning more connections diff --git a/src/Network/Rpc/Multiplexer.hs b/src/Network/Rpc/Multiplexer.hs index c17b214..e18a1ce 100644 --- a/src/Network/Rpc/Multiplexer.hs +++ b/src/Network/Rpc/Multiplexer.hs @@ -5,7 +5,7 @@ module Network.Rpc.Multiplexer ( Channel, MessageHeader(..), MessageHeaderResult(..), - -- TODO rename (this class only exists for unified error reporting and connection termination) + -- TODO rename (this class only exists for `reportProtocolError` and `reportLocalError`) HasMultiplexerProtocolWorker(..), reportProtocolError, reportLocalError, @@ -16,14 +16,17 @@ module Network.Rpc.Multiplexer ( ChannelMessageHandler, SimpleChannelMessageHandler, simpleMessageHandler, - runMultiplexerProtocol, + runMultiplexer, + newMultiplexer, ) where +import Control.Concurrent.Async (async, link) import Control.Concurrent (myThreadId, throwTo) -import Control.Exception (Exception(..), MaskingState(Unmasked), catch, finally, throwIO, getMaskingState) -import Control.Monad (when) +import Control.Exception (Exception(..), MaskingState(Unmasked), catch, finally, interruptible, throwIO, getMaskingState, mask_) +import Control.Monad (when, unless, void) import Control.Monad.IO.Class (liftIO) -import Control.Monad.State (StateT, execStateT, get, put) +import Control.Monad.State (StateT, execStateT, runStateT) +import qualified Control.Monad.State as State import Control.Concurrent.MVar import Data.Binary (Binary, encode) import qualified Data.Binary as Binary @@ -31,6 +34,7 @@ import Data.Binary.Get (Decoder(..), runGetIncremental, pushChunk, pushEndOfInpu import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BSL import qualified Data.HashMap.Strict as HM +import Data.Tuple (swap) import Data.Word import Network.Rpc.Connection import Prelude @@ -52,7 +56,9 @@ data MultiplexerProtocolMessage data MultiplexerProtocolMessageHeader = CreateChannel deriving (Binary, Generic, Show) -newtype MessageHeader = CreateChannelHeader (ChannelId -> IO ()) +newtype MessageHeader = + -- | The callback is running in a masked state and is blocking all network traffic. The callback should only be used to register a callback on the channel and to store it; then it should return immediately. + CreateChannelHeader (Channel -> IO ()) newtype MessageHeaderResult = CreateChannelHeaderResult Channel data MultiplexerProtocolWorker = MultiplexerProtocolWorker { @@ -63,7 +69,9 @@ data MultiplexerProtocolWorkerState = MultiplexerProtocolWorkerState { socketConnection :: Maybe Connection, channels :: HM.HashMap ChannelId Channel, sendChannel :: ChannelId, - receiveChannel :: ChannelId + receiveChannel :: ChannelId, + receiveNextChannelId :: ChannelId, + sendNextChannelId :: ChannelId } class HasMultiplexerProtocolWorker a where @@ -75,8 +83,50 @@ data NotConnected = NotConnected deriving Show instance Exception NotConnected -runMultiplexerProtocol :: (Channel -> IO ()) -> Connection -> IO () -runMultiplexerProtocol channelSetupHook connection = do + +data Channel = Channel { + channelId :: ChannelId, + worker :: MultiplexerProtocolWorker, + stateMVar :: MVar ChannelState, + sendStateMVar :: MVar ChannelSendState, + receiveStateMVar :: MVar ChannelReceiveState +} +instance HasMultiplexerProtocolWorker Channel where + getMultiplexerProtocolWorker = (.worker) +data ChannelState = ChannelState { + connectionState :: ChannelConnectivity, + children :: [Channel] +} +newtype ChannelSendState = ChannelSendState { + nextMessageId :: MessageId +} +data ChannelReceiveState = ChannelReceiveState { + nextMessageId :: MessageId, + handler :: ChannelMessageHandler +} + +data ChannelConnectivity = Connected | Closed | CloseConfirmed + deriving (Eq, Show) + +data ChannelNotConnected = ChannelNotConnected + deriving Show +instance Exception ChannelNotConnected + +type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> Decoder (IO ()) +type SimpleChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO () + + +-- | Starts a new multiplexer on an existing connection. +-- This starts a thread which runs until 'channelClose' is called on the resulting 'Channel' (use e.g. 'bracket' to ensure the channel is closed). +newMultiplexer :: forall a. (IsConnection a) => a -> IO Channel +newMultiplexer x = do + channelMVar <- newEmptyMVar + -- 'runMultiplexerProtcol' needs to be interruptible (so it can terminate when it is closed), so 'interruptible' is used to ensure that this function also works when used in 'bracket' + mask_ $ link =<< async (interruptible (runMultiplexer (putMVar channelMVar) (toSocketConnection x))) + takeMVar channelMVar + +runMultiplexer :: (Channel -> IO ()) -> Connection -> IO () +runMultiplexer channelSetupHook connection = do -- Running in masked state, this thread (running the receive-function) cannot be interrupted when closing the connection maskingState <- getMaskingState when (maskingState /= Unmasked) (fail "'runMultiplexerProtocol' cannot run in masked thread state.") @@ -89,13 +139,15 @@ runMultiplexerProtocol channelSetupHook connection = do socketConnection = Just connection, channels = HM.empty, sendChannel = 0, - receiveChannel = 0 + receiveChannel = 0, + receiveNextChannelId = undefined, + sendNextChannelId = undefined } let worker = MultiplexerProtocolWorker { stateMVar, killReceiverMVar } - (((channelSetupHook =<< newChannel worker 0) >> multiplexerProtocolReceive worker) + (((channelSetupHook =<< newChannel worker 0 Connected) >> multiplexerProtocolReceive worker) `finally` (disarmKillReciver >> multiplexerConnectionClose worker)) `catch` (\(_ex :: NotConnected) -> pure ()) @@ -117,14 +169,29 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder Just channel -> handleChannelMessage channel headers len Nothing -> liftIO $ reportProtocolError worker ("Received message on invalid channel: " <> show workerState.receiveChannel) handleMultiplexerMessage (SwitchChannel channelId) = liftIO $ modifyMVar_ worker.stateMVar $ \state -> pure state{receiveChannel=channelId} + handleMultiplexerMessage CloseChannel = liftIO $ do + workerState <- readMVar worker.stateMVar + case HM.lookup workerState.receiveChannel workerState.channels of + Just channel -> channelConfirmClose channel + Nothing -> reportProtocolError worker ("Received CloseChannel on invalid channel: " <> show workerState.receiveChannel) handleMultiplexerMessage x = liftIO $ print x >> undefined -- Unhandled multiplexer message handleChannelMessage :: Channel -> [MultiplexerProtocolMessageHeader] -> MessageLength -> StateT BS.ByteString IO () handleChannelMessage channel headers len = do - headerResults <- liftIO $ sequence (processHeader <$> headers) - decoder <- liftIO $ channelStartHandleMessage channel headerResults + decoder <- liftIO $ do + -- Don't receive messages on closed channels + channelState <- readMVar channel.stateMVar + case channelState.connectionState of + Connected -> do + headerResults <- sequence (processHeader <$> headers) + channelStartHandleMessage channel headerResults + -- The channel is closed but the remote might not know that yet, so the message is silently ignored + Closed -> closedChannelMessageHandler <$> sequence (processHeader <$> headers) + -- This might only be reached in some edge cases, as a closed channel will be removed from the channel map after the close is confirmed. + CloseConfirmed -> reportProtocolError worker ("Received message on channel " <> show channel.channelId <> " after receiving a close confirmation for that channel") + -- StateT currently contains leftovers - initialLeftovers <- get + initialLeftovers <- State.get let leftoversLength = fromIntegral $ BS.length initialLeftovers remaining = len - leftoversLength @@ -132,7 +199,7 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder (channelCallback, leftovers) <- liftIO $ runDecoder remaining (pushChunk decoder initialLeftovers) -- Data is received in chunks but messages have a defined length, so leftovers are put back into StateT - put leftovers + State.put leftovers -- Critical section: don't interrupt downstream callbacks liftIO $ withMVar worker.killReceiverMVar $ const channelCallback where @@ -161,44 +228,145 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder failedToTerminate :: IO a failedToTerminate = reportLocalError worker ("Decoder on channel " <> show channel.channelId <> " failed to terminate after end-of-input") processHeader :: MultiplexerProtocolMessageHeader -> IO MessageHeaderResult - processHeader CreateChannel = undefined + processHeader CreateChannel = do + channelId <- modifyMVar worker.stateMVar $ \workerState -> do + let + receiveNextChannelId = workerState.receiveNextChannelId + newWorkerState = workerState{receiveNextChannelId = receiveNextChannelId + 2} + pure (newWorkerState, receiveNextChannelId) + modifyMVar channel.stateMVar $ \state -> do + createdChannel <- newSubChannel channel.worker channelId channel + let newState = state{ + children = createdChannel : state.children + } + pure (newState, CreateChannelHeaderResult createdChannel) receiveThrowing :: IO BS.ByteString receiveThrowing = do state <- readMVar worker.stateMVar maybe (throwIO NotConnected) (.receive) state.socketConnection +closedChannelMessageHandler :: [MessageHeaderResult] -> Decoder (IO ()) +closedChannelMessageHandler headers = discardMessageDecoder $ mapM_ handleHeader headers + where + handleHeader :: MessageHeaderResult -> IO () + handleHeader (CreateChannelHeaderResult createdChannel) = + -- The channel that received the message is already closed, so newly created children are implicitly closed as well + modifyMVar_ createdChannel.stateMVar $ \state -> + pure state{connectionState = Closed} + + discardMessageDecoder :: IO () -> Decoder (IO ()) + discardMessageDecoder action = Partial (maybe done partial) + where + partial :: BS.ByteString -> Decoder (IO ()) + partial = const (discardMessageDecoder action) + done :: Decoder (IO ()) + done = Done "" 0 action + +withMultiplexerState :: MultiplexerProtocolWorker -> StateT MultiplexerProtocolWorkerState IO a -> IO a +withMultiplexerState worker action = modifyMVar worker.stateMVar $ fmap swap . runStateT action + multiplexerSend :: MultiplexerProtocolWorker -> MultiplexerProtocolMessage -> IO () -multiplexerSend worker msg = withMVar worker.stateMVar $ \state -> multiplexerStateSend state msg +multiplexerSend worker msg = withMultiplexerState worker (multiplexerStateSend msg) -multiplexerStateSend :: MultiplexerProtocolWorkerState -> MultiplexerProtocolMessage -> IO () -multiplexerStateSend state = multiplexerStateSendRaw state . encode +multiplexerStateSend :: MultiplexerProtocolMessage -> StateT MultiplexerProtocolWorkerState IO () +multiplexerStateSend = multiplexerStateSendRaw . encode -multiplexerStateSendRaw :: MultiplexerProtocolWorkerState -> BSL.ByteString -> IO () -multiplexerStateSendRaw MultiplexerProtocolWorkerState{socketConnection=Just connection} rawMsg = connection.send rawMsg -multiplexerStateSendRaw MultiplexerProtocolWorkerState{socketConnection=Nothing} _ = throwIO NotConnected +multiplexerStateSendRaw :: BSL.ByteString -> StateT MultiplexerProtocolWorkerState IO () +multiplexerStateSendRaw rawMsg = do + state <- State.get + case state.socketConnection of + Nothing -> liftIO $ throwIO NotConnected + Just connection -> liftIO $ connection.send rawMsg -multiplexerSendChannelMessage :: MultiplexerProtocolWorker -> ChannelId -> BSL.ByteString -> [MessageHeader] -> IO () -multiplexerSendChannelMessage worker channelId msg headers = do +multiplexerSendChannelMessage :: Channel -> BSL.ByteString -> [MessageHeader] -> IO () +multiplexerSendChannelMessage channel msg headers = do -- Sending a channel message consists of multiple low-level send operations, so the MVar is held during the operation - modifyMVar_ worker.stateMVar $ \state -> do - -- Switch to the specified channel (if required) - when (state.sendChannel /= channelId) $ multiplexerSend worker (SwitchChannel channelId) + withMultiplexerState worker $ do + multiplexerSwitchChannel channel.channelId headerMessages <- sequence (prepareHeader <$> headers) - multiplexerStateSend state (ChannelMessage headerMessages (fromIntegral (BSL.length msg))) - multiplexerStateSendRaw state msg - pure state{sendChannel=channelId} - where - prepareHeader :: MessageHeader -> IO MultiplexerProtocolMessageHeader - prepareHeader (CreateChannelHeader _newChannelCallback) = undefined - -multiplexerChannelClose :: MultiplexerProtocolWorker -> ChannelId -> IO () -multiplexerChannelClose worker channelId = - if channelId == 0 - then multiplexerClose worker - else undefined + multiplexerStateSend (ChannelMessage headerMessages (fromIntegral (BSL.length msg))) + multiplexerStateSendRaw msg + where + worker :: MultiplexerProtocolWorker + worker = channel.worker + prepareHeader :: MessageHeader -> StateT MultiplexerProtocolWorkerState IO MultiplexerProtocolMessageHeader + prepareHeader (CreateChannelHeader newChannelCallback) = do + nextChannelId <- State.state (\state -> (state.sendNextChannelId, state{sendNextChannelId = state.sendNextChannelId + 1})) + createdChannel <- liftIO $ newSubChannel worker nextChannelId channel + + -- TODO we probably don't want to call the callback here, as the state is locked; we also don't want to call it later, because at that point messages could already arrive and the handler has to be set + -- TODO: also we are currently holding the MultiplexerProtocolWorkerState which means sending messages from the callback will result in a deadlock - calling code must not do that. That's also an indication for a bad design + -- TODO the current design requires the caller to use mvars/iorefs to get the created channel - also not optimal. + liftIO $ newChannelCallback createdChannel + pure CreateChannel + +multiplexerSwitchChannel :: ChannelId -> StateT MultiplexerProtocolWorkerState IO () +multiplexerSwitchChannel channelId = do + -- Check if channel switch is required and update current channel + shouldSwitchChannel <- State.state (\state -> (state.sendChannel /= channelId, state{sendChannel = channelId})) + when shouldSwitchChannel $ multiplexerStateSend (SwitchChannel channelId) + +-- | Closes a channel and all it's children: After the function completes, the channel callback will no longer be called on received messages and sending messages on the channel will fail. +-- Calling close on a closed channel is a noop. +channelClose :: Channel -> IO () +channelClose channel = do + -- Change channel state of all unclosed channels in the tree to closed + channelWasClosed <- channelClose' channel + + when channelWasClosed $ do + -- Send close message + withMultiplexerState channel.worker $ do + multiplexerSwitchChannel channel.channelId + multiplexerStateSend CloseChannel + + -- Terminate the worker when the root channel is closed + when (channel.channelId == 0) $ multiplexerClose channel.worker + where + channelClose' :: Channel -> IO Bool + channelClose' chan = modifyMVar chan.stateMVar $ \state -> + case state.connectionState of + Connected -> do + -- Close all children while blocking the state. This prevents children from receiving a messages after the parent channel has already rejected a message + liftIO (mapM_ (void . channelClose') state.children) + pure (state{connectionState = Closed}, True) + -- Channel was already closed and can be ignored + Closed -> pure (state, False) + CloseConfirmed -> pure (state, False) + +-- Called on a channel when a ChannelClose message is received +channelConfirmClose :: Channel -> IO () +channelConfirmClose channel = do + closeConfirmedIds <- channelClose' channel + + -- List can only be empty when the channel was already confirmed as closed + unless (closeConfirmedIds == []) $ do + -- Remote channels from worker + withMultiplexerState channel.worker $ do + State.modify $ \state -> state{channels = foldr HM.delete state.channels closeConfirmedIds} + + -- Terminate the worker when the root channel is closed + when (channel.channelId == 0) $ multiplexerClose channel.worker + where + channelClose' :: Channel -> IO [ChannelId] + channelClose' chan = modifyMVar chan.stateMVar $ \state -> + case state.connectionState of + Connected -> do + closedIdLists <- liftIO (mapM channelClose' state.children) + let + closedIds = chan.channelId : mconcat closedIdLists + newState = state{connectionState = CloseConfirmed} + pure (newState, closedIds) + Closed -> do + closedIdLists <- liftIO (mapM channelClose' state.children) + let + closedIds = chan.channelId : mconcat closedIdLists + newState = state{connectionState = CloseConfirmed} + pure (newState, closedIds) + -- Ignore already closed children + CloseConfirmed -> pure (state, []) -- | Close a mulxiplexer worker by closing the connection it is based on and then stopping the worker thread. multiplexerClose :: MultiplexerProtocolWorker -> IO () @@ -206,6 +374,7 @@ multiplexerClose worker = do multiplexerConnectionClose worker modifyMVar_ worker.killReceiverMVar $ \killReceiver -> do killReceiver + -- Replace 'killReceiver'-action with a no-op to ensure it runs only once pure (pure ()) -- | Internal close operation: Closes the communication channel a multiplexer is operating on. The caller has the responsibility to ensure the receiver thread is closed. @@ -221,9 +390,7 @@ multiplexerConnectionClose worker = do reportProtocolError :: HasMultiplexerProtocolWorker a => a -> String -> IO b reportProtocolError hasWorker message = do let worker = getMultiplexerProtocolWorker hasWorker - modifyMVar_ worker.stateMVar $ \state -> do - multiplexerStateSend state $ ProtocolError message - pure state + multiplexerSend worker $ ProtocolError message -- TODO custom error type, close connection undefined @@ -231,31 +398,10 @@ reportLocalError :: HasMultiplexerProtocolWorker a => a -> String -> IO b reportLocalError hasWorker message = do hPutStrLn stderr message let worker = getMultiplexerProtocolWorker hasWorker - modifyMVar_ worker.stateMVar $ \state -> do - multiplexerStateSend state $ ProtocolError "Internal server error" - pure state + multiplexerSend worker $ ProtocolError "Internal server error" -- TODO custom error type, close connection undefined -data Channel = Channel { - channelId :: ChannelId, - worker :: MultiplexerProtocolWorker, - sendStateMVar :: MVar ChannelSendState, - receiveStateMVar :: MVar ChannelReceiveState -} -instance HasMultiplexerProtocolWorker Channel where - getMultiplexerProtocolWorker = (.worker) -newtype ChannelSendState = ChannelSendState { - nextMessageId :: MessageId -} -data ChannelReceiveState = ChannelReceiveState { - nextMessageId :: MessageId, - handler :: ChannelMessageHandler -} - -type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> Decoder (IO ()) -type SimpleChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO () - simpleMessageHandler :: SimpleChannelMessageHandler -> ChannelMessageHandler simpleMessageHandler handler msgId headers = decoder "" where @@ -267,10 +413,23 @@ simpleMessageHandler handler msgId headers = decoder "" done :: Decoder (IO ()) done = Done "" (BSL.length acc) (handler msgId headers acc) - --- Should not be exported -newChannel :: MultiplexerProtocolWorker -> ChannelId -> IO Channel -newChannel worker channelId = do +newSubChannel :: MultiplexerProtocolWorker -> ChannelId -> Channel -> IO Channel +newSubChannel worker channelId parent = + modifyMVar parent.stateMVar $ \parentChannelState -> do + -- Holding the parents channelState while initializing the channel will ensure the ChannelConnectivity is inherited atomically + createdChannel <- newChannel worker channelId parentChannelState.connectionState + + let newParentState = parentChannelState{ + children = createdChannel : parentChannelState.children + } + pure (newParentState, createdChannel) + +newChannel :: MultiplexerProtocolWorker -> ChannelId -> ChannelConnectivity -> IO Channel +newChannel worker channelId connectionState = do + stateMVar <- newMVar ChannelState { + connectionState, + children = [] + } sendStateMVar <- newMVar ChannelSendState { nextMessageId = 0 } @@ -282,25 +441,32 @@ newChannel worker channelId = do let channel = Channel { worker, channelId, + stateMVar, sendStateMVar, receiveStateMVar } modifyMVar_ worker.stateMVar $ \state -> pure state{channels = HM.insert channelId channel state.channels} pure channel -channelSend :: Channel -> BSL.ByteString -> [MessageHeader] -> (MessageId -> IO ()) -> IO () + +channelSend :: Channel -> [MessageHeader] -> BSL.ByteString -> (MessageId -> IO ()) -> IO () channelSend channel msg headers callback = do modifyMVar_ channel.sendStateMVar $ \state -> do + -- Don't send on closed channels + channelState <- readMVar channel.stateMVar + unless (channelState.connectionState == Connected) $ throwIO ChannelNotConnected + callback state.nextMessageId - multiplexerSendChannelMessage channel.worker channel.channelId msg headers + multiplexerSendChannelMessage channel headers msg pure state{nextMessageId = state.nextMessageId + 1} -channelSend_ :: Channel -> BSL.ByteString -> [MessageHeader] -> IO () -channelSend_ channel msg headers = channelSend channel msg headers (const (pure ())) -channelClose :: Channel -> IO () -channelClose channel = multiplexerChannelClose channel.worker channel.channelId + +channelSend_ :: Channel -> [MessageHeader] -> BSL.ByteString -> IO () +channelSend_ channel headers msg = channelSend channel headers msg (const (pure ())) + channelStartHandleMessage :: Channel -> [MessageHeaderResult] -> IO (Decoder (IO ())) channelStartHandleMessage channel headers = do (msgId, handler) <- modifyMVar channel.receiveStateMVar $ \state -> pure (state{nextMessageId = state.nextMessageId + 1}, (state.nextMessageId, state.handler)) pure (handler msgId headers) + channelSetHandler :: Channel -> ChannelMessageHandler -> IO () channelSetHandler channel handler = modifyMVar_ channel.receiveStateMVar $ \state -> pure state{handler} diff --git a/test/Network/Rpc/MultiplexerSpec.hs b/test/Network/Rpc/MultiplexerSpec.hs new file mode 100644 index 0000000..95dced1 --- /dev/null +++ b/test/Network/Rpc/MultiplexerSpec.hs @@ -0,0 +1,56 @@ +module Network.Rpc.MultiplexerSpec where + +import Control.Concurrent.Async (concurrently_) +import Control.Concurrent.MVar +import Control.Exception (bracket, mask_) +import Prelude +import Network.Rpc.Multiplexer +import Network.Rpc.Connection +import Test.Hspec + +spec :: Spec +spec = describe "runMultiplexerProtocol" $ parallel $ do + it "can be closed from the channelSetupHook" $ do + (x, _) <- newDummySocketPair + runMultiplexer channelClose x + + it "fails when run in masked state" $ do + (x, _) <- newDummySocketPair + mask_ $ runMultiplexer channelClose x `shouldThrow` anyException + + it "closes when the remote is closed" $ do + (x, y) <- newDummySocketPair + concurrently_ + (runMultiplexer (const (pure ())) x) + (runMultiplexer channelClose y) + + it "it can send and receive simple messages" $ do + withEchoServer $ \channel -> do + recvMVar <- newEmptyMVar + channelSetHandler channel $ simpleMessageHandler $ \_ _ -> putMVar recvMVar + channelSend_ channel [] "foobar" + takeMVar recvMVar `shouldReturn` "foobar" + channelSend_ channel [] "test" + takeMVar recvMVar `shouldReturn` "test" + + +withEchoServer :: (Channel -> IO a) -> IO a +withEchoServer fn = bracket setup close (\(channel, _) -> fn channel) + where + setup :: IO (Channel, Channel) + setup = do + (x, y) <- newDummySocketPair + echoChannel <- newMultiplexer y + configureEchoHandler echoChannel + mainChannel <- newMultiplexer x + pure (mainChannel, echoChannel) + close :: (Channel, Channel) -> IO () + close (x, y) = channelClose x >> channelClose y + configureEchoHandler :: Channel -> IO () + configureEchoHandler channel = channelSetHandler channel (echoHandler channel) + echoHandler :: Channel -> ChannelMessageHandler + echoHandler channel = simpleMessageHandler $ \_msgId headers msg -> do + mapM_ echoHeaderHandler headers + channelSend_ channel [] msg + echoHeaderHandler :: MessageHeaderResult -> IO () + echoHeaderHandler (CreateChannelHeaderResult channel) = configureEchoHandler channel -- GitLab