From fab29b9c79391168e558d8cea2d55aa2a5f92030 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 5 Apr 2022 22:17:23 +0200
Subject: [PATCH] Simplify multiplexer outbox and add unsafeQueueChannelMessage

---
 src/Quasar/Network/Multiplexer.hs | 110 ++++++++++++++++--------------
 1 file changed, 60 insertions(+), 50 deletions(-)

diff --git a/src/Quasar/Network/Multiplexer.hs b/src/Quasar/Network/Multiplexer.hs
index 28f379d..7c6ee92 100644
--- a/src/Quasar/Network/Multiplexer.hs
+++ b/src/Quasar/Network/Multiplexer.hs
@@ -11,6 +11,7 @@ module Quasar.Network.Multiplexer (
   defaultMessageConfiguration,
   channelSend,
   sendChannelMessage,
+  unsafeQueueChannelMessage,
   channelSend_,
   channelSendSimple,
 
@@ -85,7 +86,9 @@ data MultiplexerMessage
 
 -- ** Multiplexer
 
-type OutboxMessage = (ChannelId, NewChannelCount, BSL.ByteString)
+data OutboxMessage
+  = OutboxSendMessage ChannelId NewChannelCount BSL.ByteString
+  | OutboxCloseChannel ChannelId
 
 data Multiplexer = Multiplexer {
   side :: MultiplexerSide,
@@ -93,11 +96,10 @@ data Multiplexer = Multiplexer {
   multiplexerException :: Promise MultiplexerException,
   multiplexerResult :: Promise (Maybe MultiplexerException),
   receiveThreadCompleted :: Future (),
+  -- Set to true after magic bytes have been received
   receivedHeader :: TVar Bool,
-  outbox :: TMVar OutboxMessage,
+  outbox :: TVar [OutboxMessage],
   outboxGuard :: MVar (),
-  -- Set to true after magic bytes have been received
-  closeChannelOutbox :: TVar [ChannelId],
   channelsVar :: TVar (HM.HashMap ChannelId Channel),
   nextReceiveChannelId :: TVar ChannelId,
   nextSendChannelId :: TVar ChannelId
@@ -229,7 +231,7 @@ newChannelSTM parent@Channel{multiplexer, quasar=parentQuasar} channelId = do
 sendChannelCloseMessage :: Channel -> STM ()
 sendChannelCloseMessage channel = do
   unlessM (readTVar channel.sentCloseMessage) do
-    modifyTVar channel.multiplexer.closeChannelOutbox (channel.channelId :)
+    modifyTVar channel.multiplexer.outbox (OutboxCloseChannel channel.channelId :)
     -- Mark as closed and propagate close state to children
     markAsClosed channel
     cleanupChannel channel
@@ -330,12 +332,11 @@ newMultiplexerInternal side connection = disposeOnError do
   -- without accidentally disposing external resources.
   resourceManager <- askResourceManager
 
-  outbox <- liftIO $ newEmptyTMVarIO
+  outbox <- liftIO $ newTVarIO []
   multiplexerException <- newPromise
   multiplexerResult <- newPromise
   outboxGuard <- liftIO $ newMVar ()
   receivedHeader <- liftIO $ newTVarIO False
-  closeChannelOutbox <- liftIO $ newTVarIO mempty
   nextReceiveChannelId <- liftIO $ newTVarIO $ if side == MultiplexerSideA then 2 else 1
   nextSendChannelId <- liftIO $ newTVarIO $ if side == MultiplexerSideA then 1 else 2
 
@@ -385,7 +386,6 @@ newMultiplexerInternal side connection = disposeOnError do
         outbox,
         outboxGuard,
         receivedHeader,
-        closeChannelOutbox,
         channelsVar,
         nextReceiveChannelId,
         nextSendChannelId
@@ -440,17 +440,15 @@ sendThread multiplexer sendFn = do
           -- Send exception (if required for that exception type) and then terminate send loop
           Just fatalException -> pure $ sendException fatalException
           Nothing -> do
-            mMessage <- tryTakeTMVar multiplexer.outbox
-            closeChannelQueue <- swapTVar multiplexer.closeChannelOutbox []
-            case (mMessage, closeChannelQueue) of
+            messages <- swapTVar multiplexer.outbox []
+            case messages of
               -- Exit when the receive thread has stopped and there is no error and no message left to send
-              (Nothing, []) -> pure () <$ awaitSTM multiplexer.receiveThreadCompleted
+              [] -> pure () <$ awaitSTM multiplexer.receiveThreadCompleted
               _ -> pure do
-                msg <- execWriterT do
-                  mapM_ formatChannelMessage mMessage
-                  -- closeChannelQueue is used as a queue, so it has to be reversed to keep the order of close messages
-                  formatCloseMessages (reverse closeChannelQueue)
-                liftIO $ send msg
+                bs <- execWriterT do
+                  -- outbox is a list that is used as a queue, so it has to be reversed to preserve the correct order
+                  mapM_ formatMessage (reverse messages)
+                liftIO $ send bs
                 sendLoop
     send :: MonadIO m => Put -> m ()
     send chunks = liftIO $ sendFn (Binary.runPut chunks) `catchAll` (throwM . ConnectionLost . SendFailed)
@@ -467,20 +465,17 @@ sendThread multiplexer sendFn = do
         tell $ Binary.put $ ChannelProtocolError message
       send msg
     sendException (ReceivedChannelProtocolException _ _) = pure ()
-    formatChannelMessage :: (ChannelId, NewChannelCount, BSL.ByteString) -> WriterT Put (StateT ChannelId IO) ()
-    formatChannelMessage (channelId, newChannelCount, message) = do
+    formatMessage :: OutboxMessage -> WriterT Put (StateT ChannelId IO) ()
+    formatMessage (OutboxSendMessage channelId newChannelCount message) = do
       switchToChannel channelId
       tell do
         Binary.put (ChannelMessage newChannelCount messageLength)
         Binary.putLazyByteString message
       where
         messageLength = fromIntegral $ BSL.length message
-    formatCloseMessages :: [ChannelId] -> WriterT Put (StateT ChannelId IO) ()
-    formatCloseMessages [] = pure mempty
-    formatCloseMessages (i:is) = do
-      switchToChannel i
+    formatMessage (OutboxCloseChannel channelId) = do
+      switchToChannel channelId
       tell $ Binary.put CloseChannel
-      formatCloseMessages is
     switchToChannel :: ChannelId -> WriterT Put (StateT ChannelId IO) ()
     switchToChannel channelId = do
       currentChannelId <- State.get
@@ -638,40 +633,55 @@ channelSend = sendChannelMessage
 {-# DEPRECATED channelSend "Use sendChannelMessage instead" #-}
 
 sendChannelMessage :: MonadIO m => Channel -> MessageConfiguration -> BSL.ByteString -> (MessageId -> STM ()) -> m SentMessageResources
-sendChannelMessage channel@Channel{multiplexer} MessageConfiguration{closeChannel, createChannels} payload messageIdHook = liftIO do
-  -- NOTE At most one message can be queued per STM transaction, so `sendChannelMessage` cannot be changed to STM
-
+sendChannelMessage channel@Channel{multiplexer} messageConfiguration payload messageIdHook = liftIO do
   -- Locking the 'outboxGuard' guarantees fairness when sending messages concurrently (it also prevents unnecessary
   -- STM retries)
-  withMVar multiplexer.outboxGuard \_ -> do
-    atomically do
-      -- Abort if the multiplexer is finished or currently cleaning up
-      mapM_ throwM =<< peekFutureSTM (toFuture multiplexer.multiplexerException)
+  withMVar multiplexer.outboxGuard \_ ->
+    atomically $ sendChannelMessageInternal BlockUntilReady channel messageConfiguration payload messageIdHook
 
-      -- Abort if the channel is closed
-      verifyChannelIsConnected channel
+-- | Unsafely queue a network message to an unbounded send queue. This function does not block, even if `sendChannelMessage` would block. Queued messages will cause concurrent or following `sendChannelMessage`-calls to block until the queue is flushed.
+unsafeQueueChannelMessage :: Channel -> MessageConfiguration -> BSL.ByteString -> (MessageId -> STM ()) -> STM SentMessageResources
+unsafeQueueChannelMessage = sendChannelMessageInternal UnboundedQueue
 
-      -- Block until all previously queued close messages have been sent.
-      -- This prevents message reordering in the send thread.
-      check . null =<< readTVar multiplexer.closeChannelOutbox
 
-      -- Put the message into the outbox. It will be picked up by the send thread.
-      -- Retries (blocks) until the outbox is available.
-      putTMVar multiplexer.outbox (channel.channelId, createChannels, payload)
-      messageId <- stateTVar channel.nextSendMessageId (\x -> (x, x + 1))
-      messageIdHook messageId
+data QueueBehavior = BlockUntilReady | UnboundedQueue
 
-      when closeChannel do
-        sendChannelCloseMessage channel
-        disposeEventually_ channel
+sendChannelMessageInternal :: QueueBehavior -> Channel -> MessageConfiguration -> BSL.ByteString -> (MessageId -> STM ()) -> STM SentMessageResources
+sendChannelMessageInternal queueBehavior channel@Channel{multiplexer} MessageConfiguration{closeChannel, createChannels} payload messageIdHook = do
+  -- NOTE At most one message can be queued per STM transaction, so `sendChannelMessage` cannot be changed to STM
 
-      createdChannelIds <- stateTVar multiplexer.nextSendChannelId (createChannelIds createChannels)
-      createdChannels <- mapM (newChannelSTM channel) createdChannelIds
+  -- Abort if the multiplexer is finished or currently cleaning up
+  mapM_ throwM =<< peekFutureSTM (toFuture multiplexer.multiplexerException)
 
-      pure SentMessageResources {
-        messageId,
-        createdChannels
-      }
+  -- Abort if the channel is closed
+  verifyChannelIsConnected channel
+
+  msgs <- readTVar multiplexer.outbox
+
+  case queueBehavior of
+    BlockUntilReady ->
+      -- Block until all previously queued messages have been sent.
+      check $ null msgs
+    UnboundedQueue -> pure ()
+
+  -- Put the message into the outbox. It will be picked up by the send thread.
+  let msg = OutboxSendMessage channel.channelId createChannels payload
+  writeTVar multiplexer.outbox (msg:msgs)
+
+  messageId <- stateTVar channel.nextSendMessageId (\x -> (x, x + 1))
+  messageIdHook messageId
+
+  when closeChannel do
+    sendChannelCloseMessage channel
+    disposeEventually_ channel
+
+  createdChannelIds <- stateTVar multiplexer.nextSendChannelId (createChannelIds createChannels)
+  createdChannels <- mapM (newChannelSTM channel) createdChannelIds
+
+  pure SentMessageResources {
+    messageId,
+    createdChannels
+  }
 
 createChannelIds :: NewChannelCount -> ChannelId -> ([ChannelId], ChannelId)
 createChannelIds amount firstId = (channelIds, nextId)
-- 
GitLab