From ee61f057e2f723ee2464694e81ab344f72edd08c Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 18 May 2021 22:57:21 +0200
Subject: [PATCH] Multiplexer: Implement channel management

---
 qrpc.cabal                          |   1 +
 src/Network/Rpc.hs                  |  23 +-
 src/Network/Rpc/Connection.hs       |   7 +-
 src/Network/Rpc/Multiplexer.hs      | 314 +++++++++++++++++++++-------
 test/Network/Rpc/MultiplexerSpec.hs |  56 +++++
 5 files changed, 310 insertions(+), 91 deletions(-)
 create mode 100644 test/Network/Rpc/MultiplexerSpec.hs

diff --git a/qrpc.cabal b/qrpc.cabal
index 472a53c..a996e5e 100644
--- a/qrpc.cabal
+++ b/qrpc.cabal
@@ -94,5 +94,6 @@ test-suite qrpc-test
   main-is: Spec.hs
   other-modules:
     Network.RpcSpec
+    Network.Rpc.MultiplexerSpec
   hs-source-dirs:
     test
diff --git a/src/Network/Rpc.hs b/src/Network/Rpc.hs
index 30f6296..0190ee6 100644
--- a/src/Network/Rpc.hs
+++ b/src/Network/Rpc.hs
@@ -2,7 +2,7 @@ module Network.Rpc where
 
 import Control.Concurrent (forkFinally)
 import Control.Concurrent.Async (Async, async, link, withAsync)
-import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError)
+import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError, interruptible)
 import Control.Monad ((>=>), when, forever)
 import Control.Monad.State (State, execState)
 import qualified Control.Monad.State as State
@@ -12,14 +12,13 @@ import qualified Data.ByteString.Lazy as BSL
 import Data.Hashable (Hashable)
 import qualified Data.HashMap.Strict as HM
 import Data.Maybe (isNothing)
-import Language.Haskell.TH
+import Language.Haskell.TH hiding (interruptible)
 import Language.Haskell.TH.Syntax
 import Network.Rpc.Multiplexer
 import Network.Rpc.Connection
 import qualified Network.Socket as Socket
 import Prelude
 import GHC.Generics
-import GHC.IO (unsafeUnmask)
 import System.Posix.Files (getFileStatus, isSocket)
 
 
@@ -261,11 +260,11 @@ emptyClientState = ClientState {
 }
 
 clientSend :: RpcProtocol p => Client p -> ProtocolRequest p -> IO ()
-clientSend client req = channelSend_ client.channel (encode req) []
+clientSend client req = channelSend_ client.channel [] (encode req)
 clientRequestBlocking :: forall p. RpcProtocol p => Client p -> ProtocolRequest p -> IO (ProtocolResponse p)
 clientRequestBlocking client req = do
   resultMVar <- newEmptyMVar
-  channelSend client.channel (encode req) [] $ \msgId ->
+  channelSend client.channel [] (encode req) $ \msgId ->
     modifyMVar_ client.stateMVar $
       \state -> pure state{callbacks = HM.insert msgId (requestCompletedCallback resultMVar msgId) state.callbacks}
   -- Block on resultMVar until the request completes
@@ -305,7 +304,7 @@ serverHandleChannelMessage protocolImpl channel msgId headers msg = case decodeO
     serverHandleChannelRequest :: ProtocolRequest p -> IO ()
     serverHandleChannelRequest req = handleMessage @p protocolImpl req >>= maybe (pure ()) serverSendResponse
     serverSendResponse :: ProtocolResponse p -> IO ()
-    serverSendResponse response = channelSend_ channel (encode wrappedResponse) []
+    serverSendResponse response = channelSend_ channel [] (encode wrappedResponse)
       where
         wrappedResponse :: ProtocolResponseWrapper p
         wrappedResponse = (msgId, response)
@@ -337,11 +336,7 @@ withClient :: forall p a b. (IsConnection a, RpcProtocol p) => a -> (Client p ->
 withClient x = bracket (newClient x) clientClose
 
 newClient :: forall p a. (IsConnection a, RpcProtocol p) => a -> IO (Client p)
-newClient x = do
-  clientMVar <- newEmptyMVar
-  -- 'runMultiplexerProtcol' needs to be interruptible (so it can terminate when it is closed), so 'unsafeUnmask' is used to ensure that this function also works when used in 'bracket'
-  link =<< async (unsafeUnmask (runMultiplexerProtocol (newChannelClient >=> putMVar clientMVar) (toSocketConnection x)))
-  takeMVar clientMVar
+newClient x = newChannelClient =<< newMultiplexer (toSocketConnection x)
 
 
 newChannelClient :: RpcProtocol p => Channel -> IO (Client p)
@@ -400,7 +395,7 @@ listenOnBoundSocket protocolImpl sock = do
       Socket.gracefulClose conn 2000
 
 runServerHandler :: forall p a. (RpcProtocol p, HasProtocolImpl p, IsConnection a) => ProtocolImpl p -> a -> IO ()
-runServerHandler protocolImpl = runMultiplexerProtocol (registerChannelServerHandler @p protocolImpl) . toSocketConnection
+runServerHandler protocolImpl = runMultiplexer (registerChannelServerHandler @p protocolImpl) . toSocketConnection
 
 
 -- ** Test implementation
@@ -483,10 +478,10 @@ buildTupleType fields = buildTupleType' =<< fields
     go t (f:fs) = go (AppT t f) fs
 
 buildFunctionType :: Q [Type] -> Q Type -> Q Type
-buildFunctionType argTypes pureType = go =<< argTypes
+buildFunctionType argTypes returnType = go =<< argTypes
   where
     go :: [Type] -> Q Type
-    go [] = pureType
+    go [] = returnType
     go (t:ts) = pure t `funT` go ts
 
 defaultBangType  :: Q Type -> Q BangType
diff --git a/src/Network/Rpc/Connection.hs b/src/Network/Rpc/Connection.hs
index 95f4d0d..1b0efca 100644
--- a/src/Network/Rpc/Connection.hs
+++ b/src/Network/Rpc/Connection.hs
@@ -3,12 +3,11 @@ module Network.Rpc.Connection where
 import Control.Concurrent (threadDelay)
 import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync)
 import Control.Concurrent.MVar
-import Control.Exception (Exception(..), SomeException, bracketOnError, finally, throwIO, bracketOnError, onException)
+import Control.Exception (Exception(..), SomeException, bracketOnError, interruptible, finally, throwIO, bracketOnError, onException)
 import Control.Monad ((>=>), unless, forM_)
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as BSL
 import Data.List (intercalate)
-import GHC.IO (unsafeUnmask)
 import qualified Network.Socket as Socket
 import qualified Network.Socket.ByteString as Socket
 import qualified Network.Socket.ByteString.Lazy as SocketL
@@ -37,6 +36,8 @@ newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)]
 instance Exception ConnectionFailed where
   displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts)
 
+-- | Open a TCP connection to target host and port. Will start multiple connection attempts (i.e. retry quickly and then try other addresses) but only return the first successful connection.
+-- Throws a 'ConnectionFailed' on failure, which contains the exceptions from all failed connection attempts.
 connectTCP :: Socket.HostName -> Socket.ServiceName -> IO Socket.Socket
 connectTCP host port = do
   -- 'getAddrInfo' either pures a non-empty list or throws an exception
@@ -74,7 +75,7 @@ connectTCP host port = do
 
   -- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail
   sock <-
-    (withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar))
+    (withAsync (interruptible raceConnections) (link >=> const (readMVar sockMVar))
       `finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar))
         `onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar)
     -- As soon as we have an open connection, stop spawning more connections
diff --git a/src/Network/Rpc/Multiplexer.hs b/src/Network/Rpc/Multiplexer.hs
index c17b214..e18a1ce 100644
--- a/src/Network/Rpc/Multiplexer.hs
+++ b/src/Network/Rpc/Multiplexer.hs
@@ -5,7 +5,7 @@ module Network.Rpc.Multiplexer (
   Channel,
   MessageHeader(..),
   MessageHeaderResult(..),
-  -- TODO rename (this class only exists for unified error reporting and connection termination)
+  -- TODO rename (this class only exists for `reportProtocolError` and `reportLocalError`)
   HasMultiplexerProtocolWorker(..),
   reportProtocolError,
   reportLocalError,
@@ -16,14 +16,17 @@ module Network.Rpc.Multiplexer (
   ChannelMessageHandler,
   SimpleChannelMessageHandler,
   simpleMessageHandler,
-  runMultiplexerProtocol,
+  runMultiplexer,
+  newMultiplexer,
 ) where
 
+import Control.Concurrent.Async (async, link)
 import Control.Concurrent (myThreadId, throwTo)
-import Control.Exception (Exception(..), MaskingState(Unmasked), catch, finally, throwIO, getMaskingState)
-import Control.Monad (when)
+import Control.Exception (Exception(..), MaskingState(Unmasked), catch, finally, interruptible, throwIO, getMaskingState, mask_)
+import Control.Monad (when, unless, void)
 import Control.Monad.IO.Class (liftIO)
-import Control.Monad.State (StateT, execStateT, get, put)
+import Control.Monad.State (StateT, execStateT, runStateT)
+import qualified Control.Monad.State as State
 import Control.Concurrent.MVar
 import Data.Binary (Binary, encode)
 import qualified Data.Binary as Binary
@@ -31,6 +34,7 @@ import Data.Binary.Get (Decoder(..), runGetIncremental, pushChunk, pushEndOfInpu
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as BSL
 import qualified Data.HashMap.Strict as HM
+import Data.Tuple (swap)
 import Data.Word
 import Network.Rpc.Connection
 import Prelude
@@ -52,7 +56,9 @@ data MultiplexerProtocolMessage
 data MultiplexerProtocolMessageHeader = CreateChannel
   deriving (Binary, Generic, Show)
 
-newtype MessageHeader = CreateChannelHeader (ChannelId -> IO ())
+newtype MessageHeader =
+  -- | The callback is running in a masked state and is blocking all network traffic. The callback should only be used to register a callback on the channel and to store it; then it should return immediately.
+  CreateChannelHeader (Channel -> IO ())
 newtype MessageHeaderResult = CreateChannelHeaderResult Channel
 
 data MultiplexerProtocolWorker = MultiplexerProtocolWorker {
@@ -63,7 +69,9 @@ data MultiplexerProtocolWorkerState = MultiplexerProtocolWorkerState {
   socketConnection :: Maybe Connection,
   channels :: HM.HashMap ChannelId Channel,
   sendChannel :: ChannelId,
-  receiveChannel :: ChannelId
+  receiveChannel :: ChannelId,
+  receiveNextChannelId :: ChannelId,
+  sendNextChannelId :: ChannelId
 }
 
 class HasMultiplexerProtocolWorker a where
@@ -75,8 +83,50 @@ data NotConnected = NotConnected
   deriving Show
 instance Exception NotConnected
 
-runMultiplexerProtocol :: (Channel -> IO ()) -> Connection -> IO ()
-runMultiplexerProtocol channelSetupHook connection = do
+
+data Channel = Channel {
+  channelId :: ChannelId,
+  worker :: MultiplexerProtocolWorker,
+  stateMVar :: MVar ChannelState,
+  sendStateMVar :: MVar ChannelSendState,
+  receiveStateMVar :: MVar ChannelReceiveState
+}
+instance HasMultiplexerProtocolWorker Channel where
+  getMultiplexerProtocolWorker = (.worker)
+data ChannelState = ChannelState {
+  connectionState :: ChannelConnectivity,
+  children :: [Channel]
+}
+newtype ChannelSendState = ChannelSendState {
+  nextMessageId :: MessageId
+}
+data ChannelReceiveState = ChannelReceiveState {
+  nextMessageId :: MessageId,
+  handler :: ChannelMessageHandler
+}
+
+data ChannelConnectivity = Connected | Closed | CloseConfirmed
+  deriving (Eq, Show)
+
+data ChannelNotConnected = ChannelNotConnected
+  deriving Show
+instance Exception ChannelNotConnected
+
+type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> Decoder (IO ())
+type SimpleChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO ()
+
+
+-- | Starts a new multiplexer on an existing connection.
+-- This starts a thread which runs until 'channelClose' is called on the resulting 'Channel' (use e.g. 'bracket' to ensure the channel is closed).
+newMultiplexer :: forall a. (IsConnection a) => a -> IO Channel
+newMultiplexer x = do
+  channelMVar <- newEmptyMVar
+  -- 'runMultiplexerProtcol' needs to be interruptible (so it can terminate when it is closed), so 'interruptible' is used to ensure that this function also works when used in 'bracket'
+  mask_ $ link =<< async (interruptible (runMultiplexer (putMVar channelMVar) (toSocketConnection x)))
+  takeMVar channelMVar
+
+runMultiplexer :: (Channel -> IO ()) -> Connection -> IO ()
+runMultiplexer channelSetupHook connection = do
   -- Running in masked state, this thread (running the receive-function) cannot be interrupted when closing the connection
   maskingState <- getMaskingState
   when (maskingState /= Unmasked) (fail "'runMultiplexerProtocol' cannot run in masked thread state.")
@@ -89,13 +139,15 @@ runMultiplexerProtocol channelSetupHook connection = do
     socketConnection = Just connection,
     channels = HM.empty,
     sendChannel = 0,
-    receiveChannel = 0
+    receiveChannel = 0,
+    receiveNextChannelId = undefined,
+    sendNextChannelId = undefined
   }
   let worker = MultiplexerProtocolWorker {
     stateMVar,
     killReceiverMVar
   }
-  (((channelSetupHook =<< newChannel worker 0) >> multiplexerProtocolReceive worker)
+  (((channelSetupHook =<< newChannel worker 0 Connected) >> multiplexerProtocolReceive worker)
     `finally` (disarmKillReciver >> multiplexerConnectionClose worker))
       `catch` (\(_ex :: NotConnected) -> pure ())
 
@@ -117,14 +169,29 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder
         Just channel -> handleChannelMessage channel headers len
         Nothing -> liftIO $ reportProtocolError worker ("Received message on invalid channel: " <> show workerState.receiveChannel)
     handleMultiplexerMessage (SwitchChannel channelId) = liftIO $ modifyMVar_ worker.stateMVar $ \state -> pure state{receiveChannel=channelId}
+    handleMultiplexerMessage CloseChannel = liftIO $ do
+      workerState <- readMVar worker.stateMVar
+      case HM.lookup workerState.receiveChannel workerState.channels of
+        Just channel -> channelConfirmClose channel
+        Nothing -> reportProtocolError worker ("Received CloseChannel on invalid channel: " <> show workerState.receiveChannel)
     handleMultiplexerMessage x = liftIO $ print x >> undefined -- Unhandled multiplexer message
 
     handleChannelMessage :: Channel -> [MultiplexerProtocolMessageHeader] -> MessageLength -> StateT BS.ByteString IO ()
     handleChannelMessage channel headers len = do
-      headerResults <- liftIO $ sequence (processHeader <$> headers)
-      decoder <- liftIO $ channelStartHandleMessage channel headerResults
+      decoder <- liftIO $ do
+        -- Don't receive messages on closed channels
+        channelState <- readMVar channel.stateMVar
+        case channelState.connectionState of
+          Connected -> do
+            headerResults <- sequence (processHeader <$> headers)
+            channelStartHandleMessage channel headerResults
+          -- The channel is closed but the remote might not know that yet, so the message is silently ignored
+          Closed -> closedChannelMessageHandler <$> sequence (processHeader <$> headers)
+          -- This might only be reached in some edge cases, as a closed channel will be removed from the channel map after the close is confirmed.
+          CloseConfirmed -> reportProtocolError worker ("Received message on channel " <> show channel.channelId <> " after receiving a close confirmation for that channel")
+
       -- StateT currently contains leftovers
-      initialLeftovers <- get
+      initialLeftovers <- State.get
       let
         leftoversLength = fromIntegral $ BS.length initialLeftovers
         remaining = len - leftoversLength
@@ -132,7 +199,7 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder
       (channelCallback, leftovers) <- liftIO $ runDecoder remaining (pushChunk decoder initialLeftovers)
 
       -- Data is received in chunks but messages have a defined length, so leftovers are put back into StateT
-      put leftovers
+      State.put leftovers
       -- Critical section: don't interrupt downstream callbacks
       liftIO $ withMVar worker.killReceiverMVar $ const channelCallback
       where
@@ -161,44 +228,145 @@ multiplexerProtocolReceive worker = receiveThreadLoop multiplexerDecoder
         failedToTerminate :: IO a
         failedToTerminate = reportLocalError worker ("Decoder on channel " <> show channel.channelId <> " failed to terminate after end-of-input")
         processHeader :: MultiplexerProtocolMessageHeader -> IO MessageHeaderResult
-        processHeader CreateChannel = undefined
+        processHeader CreateChannel = do
+          channelId <- modifyMVar worker.stateMVar $ \workerState -> do
+            let
+              receiveNextChannelId = workerState.receiveNextChannelId
+              newWorkerState = workerState{receiveNextChannelId = receiveNextChannelId + 2}
+            pure (newWorkerState, receiveNextChannelId)
+          modifyMVar channel.stateMVar $ \state -> do
+            createdChannel <- newSubChannel channel.worker channelId channel
+            let newState = state{
+              children = createdChannel : state.children
+            }
+            pure (newState, CreateChannelHeaderResult createdChannel)
     receiveThrowing :: IO BS.ByteString
     receiveThrowing = do
       state <- readMVar worker.stateMVar
       maybe (throwIO NotConnected) (.receive) state.socketConnection
 
 
+closedChannelMessageHandler :: [MessageHeaderResult] -> Decoder (IO ())
+closedChannelMessageHandler headers = discardMessageDecoder $ mapM_ handleHeader headers
+  where
+    handleHeader :: MessageHeaderResult -> IO ()
+    handleHeader (CreateChannelHeaderResult createdChannel) =
+      -- The channel that received the message is already closed, so newly created children are implicitly closed as well
+      modifyMVar_ createdChannel.stateMVar $ \state ->
+        pure state{connectionState = Closed}
+
+    discardMessageDecoder :: IO () -> Decoder (IO ())
+    discardMessageDecoder action = Partial (maybe done partial)
+      where
+        partial :: BS.ByteString -> Decoder (IO ())
+        partial = const (discardMessageDecoder action)
+        done :: Decoder (IO ())
+        done = Done "" 0 action
+
+withMultiplexerState :: MultiplexerProtocolWorker -> StateT MultiplexerProtocolWorkerState IO a -> IO a
+withMultiplexerState worker action = modifyMVar worker.stateMVar $ fmap swap . runStateT action
+
 multiplexerSend :: MultiplexerProtocolWorker -> MultiplexerProtocolMessage -> IO ()
-multiplexerSend worker msg = withMVar worker.stateMVar $ \state -> multiplexerStateSend state msg
+multiplexerSend worker msg = withMultiplexerState worker (multiplexerStateSend msg)
 
-multiplexerStateSend :: MultiplexerProtocolWorkerState -> MultiplexerProtocolMessage -> IO ()
-multiplexerStateSend state = multiplexerStateSendRaw state . encode
+multiplexerStateSend :: MultiplexerProtocolMessage -> StateT MultiplexerProtocolWorkerState IO ()
+multiplexerStateSend = multiplexerStateSendRaw . encode
 
-multiplexerStateSendRaw :: MultiplexerProtocolWorkerState -> BSL.ByteString -> IO ()
-multiplexerStateSendRaw MultiplexerProtocolWorkerState{socketConnection=Just connection} rawMsg = connection.send rawMsg
-multiplexerStateSendRaw MultiplexerProtocolWorkerState{socketConnection=Nothing} _ = throwIO NotConnected
+multiplexerStateSendRaw :: BSL.ByteString -> StateT MultiplexerProtocolWorkerState IO ()
+multiplexerStateSendRaw rawMsg = do
+  state <- State.get
+  case state.socketConnection of
+    Nothing -> liftIO $ throwIO NotConnected
+    Just connection -> liftIO $ connection.send rawMsg
 
-multiplexerSendChannelMessage :: MultiplexerProtocolWorker -> ChannelId -> BSL.ByteString -> [MessageHeader] -> IO ()
-multiplexerSendChannelMessage worker channelId msg headers = do
+multiplexerSendChannelMessage :: Channel -> BSL.ByteString -> [MessageHeader] -> IO ()
+multiplexerSendChannelMessage channel msg headers = do
   -- Sending a channel message consists of multiple low-level send operations, so the MVar is held during the operation
-  modifyMVar_ worker.stateMVar $ \state -> do
-    -- Switch to the specified channel (if required)
-    when (state.sendChannel /= channelId) $ multiplexerSend worker (SwitchChannel channelId)
+  withMultiplexerState worker $ do
+    multiplexerSwitchChannel channel.channelId
 
     headerMessages <- sequence (prepareHeader <$> headers)
-    multiplexerStateSend state (ChannelMessage headerMessages (fromIntegral (BSL.length msg)))
-    multiplexerStateSendRaw state msg
-    pure state{sendChannel=channelId}
-  where
-    prepareHeader :: MessageHeader -> IO MultiplexerProtocolMessageHeader
-    prepareHeader (CreateChannelHeader _newChannelCallback) = undefined
 
-
-multiplexerChannelClose :: MultiplexerProtocolWorker -> ChannelId -> IO ()
-multiplexerChannelClose worker channelId =
-  if channelId == 0
-    then multiplexerClose worker
-    else undefined
+    multiplexerStateSend (ChannelMessage headerMessages (fromIntegral (BSL.length msg)))
+    multiplexerStateSendRaw msg
+  where
+    worker :: MultiplexerProtocolWorker
+    worker = channel.worker
+    prepareHeader :: MessageHeader -> StateT MultiplexerProtocolWorkerState IO MultiplexerProtocolMessageHeader
+    prepareHeader (CreateChannelHeader newChannelCallback) = do
+      nextChannelId <- State.state (\state -> (state.sendNextChannelId, state{sendNextChannelId = state.sendNextChannelId + 1}))
+      createdChannel <- liftIO $ newSubChannel worker nextChannelId channel
+
+      -- TODO we probably don't want to call the callback here, as the state is locked; we also don't want to call it later, because at that point messages could already arrive and the handler has to be set
+      -- TODO: also we are currently holding the MultiplexerProtocolWorkerState which means sending messages from the callback will result in a deadlock - calling code must not do that. That's also an indication for a bad design
+      -- TODO the current design requires the caller to use mvars/iorefs to get the created channel - also not optimal.
+      liftIO $ newChannelCallback createdChannel
+      pure CreateChannel
+
+multiplexerSwitchChannel :: ChannelId -> StateT MultiplexerProtocolWorkerState IO ()
+multiplexerSwitchChannel channelId = do
+  -- Check if channel switch is required and update current channel
+  shouldSwitchChannel <- State.state (\state -> (state.sendChannel /= channelId, state{sendChannel = channelId}))
+  when shouldSwitchChannel $ multiplexerStateSend (SwitchChannel channelId)
+
+-- | Closes a channel and all it's children: After the function completes, the channel callback will no longer be called on received messages and sending messages on the channel will fail.
+-- Calling close on a closed channel is a noop.
+channelClose :: Channel -> IO ()
+channelClose channel = do
+  -- Change channel state of all unclosed channels in the tree to closed
+  channelWasClosed <- channelClose' channel
+
+  when channelWasClosed $ do
+    -- Send close message
+    withMultiplexerState channel.worker $ do
+      multiplexerSwitchChannel channel.channelId
+      multiplexerStateSend CloseChannel
+
+    -- Terminate the worker when the root channel is closed
+    when (channel.channelId == 0) $ multiplexerClose channel.worker
+  where
+    channelClose' :: Channel -> IO Bool
+    channelClose' chan = modifyMVar chan.stateMVar $ \state ->
+      case state.connectionState of
+        Connected -> do
+          -- Close all children while blocking the state. This prevents children from receiving a messages after the parent channel has already rejected a message
+          liftIO (mapM_ (void . channelClose') state.children)
+          pure (state{connectionState = Closed}, True)
+        -- Channel was already closed and can be ignored
+        Closed -> pure (state, False)
+        CloseConfirmed -> pure (state, False)
+
+-- Called on a channel when a ChannelClose message is received
+channelConfirmClose :: Channel -> IO ()
+channelConfirmClose channel = do
+  closeConfirmedIds <- channelClose' channel
+
+  -- List can only be empty when the channel was already confirmed as closed
+  unless (closeConfirmedIds == []) $ do
+    -- Remote channels from worker
+    withMultiplexerState channel.worker $ do
+      State.modify $ \state -> state{channels = foldr HM.delete state.channels closeConfirmedIds}
+
+    -- Terminate the worker when the root channel is closed
+    when (channel.channelId == 0) $ multiplexerClose channel.worker
+  where
+    channelClose' :: Channel -> IO [ChannelId]
+    channelClose' chan = modifyMVar chan.stateMVar $ \state ->
+      case state.connectionState of
+        Connected -> do
+          closedIdLists <- liftIO (mapM channelClose' state.children)
+          let
+            closedIds = chan.channelId : mconcat closedIdLists
+            newState = state{connectionState = CloseConfirmed}
+          pure (newState, closedIds)
+        Closed -> do
+          closedIdLists <- liftIO (mapM channelClose' state.children)
+          let
+            closedIds = chan.channelId : mconcat closedIdLists
+            newState = state{connectionState = CloseConfirmed}
+          pure (newState, closedIds)
+        -- Ignore already closed children
+        CloseConfirmed -> pure (state, [])
 
 -- | Close a mulxiplexer worker by closing the connection it is based on and then stopping the worker thread.
 multiplexerClose :: MultiplexerProtocolWorker -> IO ()
@@ -206,6 +374,7 @@ multiplexerClose worker = do
   multiplexerConnectionClose worker
   modifyMVar_ worker.killReceiverMVar $ \killReceiver -> do
     killReceiver
+    -- Replace 'killReceiver'-action with a no-op to ensure it runs only once
     pure (pure ())
 
 -- | Internal close operation: Closes the communication channel a multiplexer is operating on. The caller has the responsibility to ensure the receiver thread is closed.
@@ -221,9 +390,7 @@ multiplexerConnectionClose worker = do
 reportProtocolError :: HasMultiplexerProtocolWorker a => a -> String -> IO b
 reportProtocolError hasWorker message = do
   let worker = getMultiplexerProtocolWorker hasWorker
-  modifyMVar_ worker.stateMVar $ \state -> do
-    multiplexerStateSend state $ ProtocolError message
-    pure state
+  multiplexerSend worker $ ProtocolError message
   -- TODO custom error type, close connection
   undefined
 
@@ -231,31 +398,10 @@ reportLocalError :: HasMultiplexerProtocolWorker a => a -> String -> IO b
 reportLocalError hasWorker message = do
   hPutStrLn stderr message
   let worker = getMultiplexerProtocolWorker hasWorker
-  modifyMVar_ worker.stateMVar $ \state -> do
-    multiplexerStateSend state $ ProtocolError "Internal server error"
-    pure state
+  multiplexerSend worker $ ProtocolError "Internal server error"
   -- TODO custom error type, close connection
   undefined
 
-data Channel = Channel {
-  channelId :: ChannelId,
-  worker :: MultiplexerProtocolWorker,
-  sendStateMVar :: MVar ChannelSendState,
-  receiveStateMVar :: MVar ChannelReceiveState
-}
-instance HasMultiplexerProtocolWorker Channel where
-  getMultiplexerProtocolWorker = (.worker)
-newtype ChannelSendState = ChannelSendState {
-  nextMessageId :: MessageId
-}
-data ChannelReceiveState = ChannelReceiveState {
-  nextMessageId :: MessageId,
-  handler :: ChannelMessageHandler
-}
-
-type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> Decoder (IO ())
-type SimpleChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO ()
-
 simpleMessageHandler :: SimpleChannelMessageHandler -> ChannelMessageHandler
 simpleMessageHandler handler msgId headers = decoder ""
   where
@@ -267,10 +413,23 @@ simpleMessageHandler handler msgId headers = decoder ""
         done :: Decoder (IO ())
         done = Done "" (BSL.length acc) (handler msgId headers acc)
 
-
--- Should not be exported
-newChannel :: MultiplexerProtocolWorker -> ChannelId -> IO Channel
-newChannel worker channelId = do
+newSubChannel :: MultiplexerProtocolWorker -> ChannelId -> Channel -> IO Channel
+newSubChannel worker channelId parent =
+  modifyMVar parent.stateMVar $ \parentChannelState -> do
+    -- Holding the parents channelState while initializing the channel will ensure the ChannelConnectivity is inherited atomically
+    createdChannel <- newChannel worker channelId parentChannelState.connectionState
+
+    let newParentState = parentChannelState{
+      children = createdChannel : parentChannelState.children
+    }
+    pure (newParentState, createdChannel)
+
+newChannel :: MultiplexerProtocolWorker -> ChannelId -> ChannelConnectivity -> IO Channel
+newChannel worker channelId connectionState = do
+  stateMVar <- newMVar ChannelState {
+    connectionState,
+    children = []
+  }
   sendStateMVar <- newMVar ChannelSendState {
     nextMessageId = 0
   }
@@ -282,25 +441,32 @@ newChannel worker channelId = do
   let channel = Channel {
     worker,
     channelId,
+    stateMVar,
     sendStateMVar,
     receiveStateMVar
   }
   modifyMVar_ worker.stateMVar $ \state -> pure state{channels = HM.insert channelId channel state.channels}
   pure channel
-channelSend :: Channel -> BSL.ByteString -> [MessageHeader] -> (MessageId -> IO ()) -> IO ()
+
+channelSend :: Channel -> [MessageHeader] -> BSL.ByteString -> (MessageId -> IO ()) -> IO ()
 channelSend channel msg headers callback = do
   modifyMVar_ channel.sendStateMVar $ \state -> do
+    -- Don't send on closed channels
+    channelState <- readMVar channel.stateMVar
+    unless (channelState.connectionState == Connected) $ throwIO ChannelNotConnected
+
     callback state.nextMessageId
-    multiplexerSendChannelMessage channel.worker channel.channelId msg headers
+    multiplexerSendChannelMessage channel headers msg
     pure state{nextMessageId = state.nextMessageId + 1}
-channelSend_ :: Channel -> BSL.ByteString -> [MessageHeader] -> IO ()
-channelSend_ channel msg headers = channelSend channel msg headers (const (pure ()))
-channelClose :: Channel -> IO ()
-channelClose channel = multiplexerChannelClose channel.worker channel.channelId
+
+channelSend_ :: Channel -> [MessageHeader] -> BSL.ByteString -> IO ()
+channelSend_ channel headers msg = channelSend channel headers msg (const (pure ()))
+
 channelStartHandleMessage :: Channel -> [MessageHeaderResult] -> IO (Decoder (IO ()))
 channelStartHandleMessage channel headers = do
   (msgId, handler) <- modifyMVar channel.receiveStateMVar $ \state ->
     pure (state{nextMessageId = state.nextMessageId + 1}, (state.nextMessageId, state.handler))
   pure (handler msgId headers)
+
 channelSetHandler :: Channel -> ChannelMessageHandler -> IO ()
 channelSetHandler channel handler = modifyMVar_ channel.receiveStateMVar $ \state -> pure state{handler}
diff --git a/test/Network/Rpc/MultiplexerSpec.hs b/test/Network/Rpc/MultiplexerSpec.hs
new file mode 100644
index 0000000..95dced1
--- /dev/null
+++ b/test/Network/Rpc/MultiplexerSpec.hs
@@ -0,0 +1,56 @@
+module Network.Rpc.MultiplexerSpec where
+
+import Control.Concurrent.Async (concurrently_)
+import Control.Concurrent.MVar
+import Control.Exception (bracket, mask_)
+import Prelude
+import Network.Rpc.Multiplexer
+import Network.Rpc.Connection
+import Test.Hspec
+
+spec :: Spec
+spec = describe "runMultiplexerProtocol" $ parallel $ do
+  it "can be closed from the channelSetupHook" $ do
+    (x, _) <- newDummySocketPair
+    runMultiplexer channelClose x
+
+  it "fails when run in masked state" $ do
+    (x, _) <- newDummySocketPair
+    mask_ $ runMultiplexer channelClose x `shouldThrow` anyException
+
+  it "closes when the remote is closed" $ do
+    (x, y) <- newDummySocketPair
+    concurrently_
+      (runMultiplexer (const (pure ())) x)
+      (runMultiplexer channelClose y)
+
+  it "it can send and receive simple messages" $ do
+    withEchoServer $ \channel -> do
+      recvMVar <- newEmptyMVar
+      channelSetHandler channel $ simpleMessageHandler $ \_ _ -> putMVar recvMVar
+      channelSend_ channel [] "foobar"
+      takeMVar recvMVar `shouldReturn` "foobar"
+      channelSend_ channel [] "test"
+      takeMVar recvMVar `shouldReturn` "test"
+
+
+withEchoServer :: (Channel -> IO a) -> IO a
+withEchoServer fn = bracket setup close (\(channel, _) -> fn channel)
+  where
+    setup :: IO (Channel, Channel)
+    setup = do
+      (x, y) <- newDummySocketPair
+      echoChannel <- newMultiplexer y
+      configureEchoHandler echoChannel
+      mainChannel <- newMultiplexer x
+      pure (mainChannel, echoChannel)
+    close :: (Channel, Channel) -> IO ()
+    close (x, y) = channelClose x >> channelClose y
+    configureEchoHandler :: Channel -> IO ()
+    configureEchoHandler channel = channelSetHandler channel (echoHandler channel)
+    echoHandler :: Channel -> ChannelMessageHandler
+    echoHandler channel = simpleMessageHandler $ \_msgId headers msg -> do
+      mapM_ echoHeaderHandler headers
+      channelSend_ channel [] msg
+    echoHeaderHandler :: MessageHeaderResult -> IO ()
+    echoHeaderHandler (CreateChannelHeaderResult channel) = configureEchoHandler channel
-- 
GitLab