diff --git a/src/Network/Rpc/Multiplexer.hs b/src/Network/Rpc/Multiplexer.hs index d53693a5a42577ea41efc3f7ebd50189d683f576..21ccd3637cf0c55db6e10db9a08f6867f4a4714b 100644 --- a/src/Network/Rpc/Multiplexer.hs +++ b/src/Network/Rpc/Multiplexer.hs @@ -24,8 +24,8 @@ module Network.Rpc.Multiplexer ( import Control.Concurrent.Async (async, link) import Control.Concurrent (myThreadId, throwTo) -import Control.Exception (Exception(..), MaskingState(Unmasked), catch, finally, interruptible, throwIO, getMaskingState, mask_) -import Control.Monad (when, unless, void) +import Control.Exception (Exception(..), MaskingState(Unmasked), catch, handle, finally, interruptible, throwIO, getMaskingState, mask_) +import Control.Monad (when, unless) import Control.Monad.IO.Class (liftIO) import Control.Monad.State (StateT, execStateT, runStateT, lift) import qualified Control.Monad.State as State @@ -384,21 +384,23 @@ channelClose channel = do -- Change channel state of all unclosed channels in the tree to closed channelWasClosed <- channelClose' channel - when channelWasClosed $ do - -- Send close message - withMultiplexerState channel.worker $ do - multiplexerSwitchChannel channel.channelId - multiplexerStateSend CloseChannel + when channelWasClosed $ + -- Closing a channel on a Connection that is no longer connected should not throw an exception (channelClose is a resource management operation and is supposed to be idempotent) + handle (\(_ :: NotConnected) -> pure ()) $ do + -- Send close message + withMultiplexerState channel.worker $ do + multiplexerSwitchChannel channel.channelId + multiplexerStateSend CloseChannel - -- Terminate the worker when the root channel is closed - when (channel.channelId == 0) $ multiplexerClose channel.worker + -- Terminate the worker when the root channel is closed + when (channel.channelId == 0) $ multiplexerClose channel.worker where channelClose' :: Channel -> IO Bool channelClose' chan = modifyMVar chan.stateMVar $ \state -> case state.connectionState of Connected -> do -- Close all children while blocking the state. This prevents children from receiving a messages after the parent channel has already rejected a message - liftIO (mapM_ (void . channelClose') state.children) + liftIO (mapM_ channelClose' state.children) pure (state{connectionState = Closed}, True) -- Channel was already closed and can be ignored Closed -> pure (state, False) @@ -410,7 +412,7 @@ channelConfirmClose channel = do closeConfirmedIds <- channelClose' channel -- List can only be empty when the channel was already confirmed as closed - unless (closeConfirmedIds == []) $ do + unless (null closeConfirmedIds) $ do -- Remote channels from worker withMultiplexerState channel.worker $ do State.modify $ \state -> state{channels = foldr HM.delete state.channels closeConfirmedIds} diff --git a/test/Network/Rpc/MultiplexerSpec.hs b/test/Network/Rpc/MultiplexerSpec.hs index 8c5d3257c0b0279852c369b64614577bf1f648e4..0efdca1e925b30fbd272dde545019f74e36538cc 100644 --- a/test/Network/Rpc/MultiplexerSpec.hs +++ b/test/Network/Rpc/MultiplexerSpec.hs @@ -28,7 +28,7 @@ spec = describe "runMultiplexerProtocol" $ parallel $ do it "it can send and receive simple messages" $ do recvMVar <- newEmptyMVar withEchoServer $ \channel -> do - channelSetHandler channel $ ((\_ -> putMVar recvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) + channelSetHandler channel ((\_ -> putMVar recvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) channelSendSimple channel "foobar" takeMVar recvMVar `shouldReturn` "foobar" channelSendSimple channel "test" @@ -39,7 +39,7 @@ spec = describe "runMultiplexerProtocol" $ parallel $ do it "it can create sub-channels" $ do recvMVar <- newEmptyMVar withEchoServer $ \channel -> do - channelSetHandler channel $ ((\_ -> putMVar recvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) + channelSetHandler channel ((\_ -> putMVar recvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) SentMessageResources{createdChannels=[_]} <- channelSend_ channel [CreateChannelHeader] "create a channel" takeMVar recvMVar `shouldReturn` "create a channel" SentMessageResources{createdChannels=[_, _, _]} <- channelSend_ channel [CreateChannelHeader, CreateChannelHeader, CreateChannelHeader] "create more channels" @@ -58,8 +58,8 @@ spec = describe "runMultiplexerProtocol" $ parallel $ do SentMessageResources{createdChannels=[c1, c2]} <- channelSend_ channel [CreateChannelHeader, CreateChannelHeader] "create channels" takeMVar recvMVar `shouldReturn` "create channels" - channelSetHandler c1 $ ((\_ -> putMVar c1RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) - channelSetHandler c2 $ ((\_ -> putMVar c2RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) + channelSetHandler c1 ((\_ -> putMVar c1RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) + channelSetHandler c2 ((\_ -> putMVar c2RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) channelSendSimple c1 "test" takeMVar c1RecvMVar `shouldReturn` "test" @@ -72,7 +72,7 @@ spec = describe "runMultiplexerProtocol" $ parallel $ do SentMessageResources{createdChannels=[c3]} <- channelSend_ channel [CreateChannelHeader] "create another channel" takeMVar recvMVar `shouldReturn` "create another channel" - channelSetHandler c3 $ ((\_ -> putMVar c3RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) + channelSetHandler c3 ((\_ -> putMVar c3RecvMVar) :: ReceivedMessageResources -> BSL.ByteString -> IO ()) channelSendSimple c3 "test5" takeMVar c3RecvMVar `shouldReturn` "test5"