From 1eaa018d271dfa2bd8fbdb62ab7a486056e7d3e5 Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Mon, 10 May 2021 15:05:05 +0200 Subject: [PATCH] Move socket abstraction into a dedicated module --- qrpc.cabal | 1 + src/Network/Rpc.hs | 92 +++------------------------ src/Network/Rpc/Connection.hs | 112 +++++++++++++++++++++++++++++++++ src/Network/Rpc/Multiplexer.hs | 23 +------ 4 files changed, 124 insertions(+), 104 deletions(-) create mode 100644 src/Network/Rpc/Connection.hs diff --git a/qrpc.cabal b/qrpc.cabal index 9e4898f..472a53c 100644 --- a/qrpc.cabal +++ b/qrpc.cabal @@ -80,6 +80,7 @@ library import: shared-properties exposed-modules: Network.Rpc + Network.Rpc.Connection Network.Rpc.Multiplexer hs-source-dirs: src diff --git a/src/Network/Rpc.hs b/src/Network/Rpc.hs index 4c2398f..aefab16 100644 --- a/src/Network/Rpc.hs +++ b/src/Network/Rpc.hs @@ -1,9 +1,9 @@ module Network.Rpc where -import Control.Concurrent (threadDelay, forkFinally) -import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync) -import Control.Exception (Exception(..), SomeException, bracket, bracketOnError, finally, throwIO, bracketOnError, onException) -import Control.Monad ((>=>), when, unless, forever, forM_) +import Control.Concurrent (forkFinally) +import Control.Concurrent.Async (Async, async, link, withAsync) +import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError) +import Control.Monad ((>=>), when, forever) import Control.Monad.State (State, execState) import qualified Control.Monad.State as State import Control.Concurrent.MVar @@ -11,15 +11,15 @@ import Data.Binary (Binary, encode, decodeOrFail) import qualified Data.ByteString.Lazy as BSL import Data.Hashable (Hashable) import qualified Data.HashMap.Strict as HM -import Data.List (intercalate) import Data.Maybe (isNothing) import Language.Haskell.TH import Language.Haskell.TH.Syntax import Network.Rpc.Multiplexer +import Network.Rpc.Connection import qualified Network.Socket as Socket import Prelude -import GHC.IO (unsafeUnmask) import GHC.Generics +import GHC.IO (unsafeUnmask) import System.Posix.Files (getFileStatus, isSocket) @@ -316,64 +316,12 @@ registerChannelServerHandler protocolImpl channel = channelSetHandler channel (s -- ** Running client and server -newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)] - deriving (Show) -instance Exception ConnectionFailed where - displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts) - withClientTCP :: RpcProtocol p => Socket.HostName -> Socket.ServiceName -> (Client p -> IO a) -> IO a withClientTCP host port = bracket (newClientTCP host port) clientClose newClientTCP :: forall p. RpcProtocol p => Socket.HostName -> Socket.ServiceName -> IO (Client p) -newClientTCP host port = do - -- 'getAddrInfo' either pures a non-empty list or throws an exception - (best:others) <- Socket.getAddrInfo (Just hints) (Just host) (Just port) - - connectTasksMVar <- newMVar [] - sockMVar <- newEmptyMVar - let - spawnConnectTask :: Socket.AddrInfo -> IO () - spawnConnectTask = \addr -> modifyMVar_ connectTasksMVar $ \old -> (:old) . (addr,) <$> connectTask addr - -- Race more connections (a missed TCP SYN will result in 3s wait before a retransmission; IPv6 might be broken) - -- Inspired by a similar implementation in browsers - raceConnections :: IO () - raceConnections = do - spawnConnectTask best - threadDelay 200000 - -- Give the "best" address another try, in case the TCP SYN gets dropped (kernel retry interval can be multiple seconds long) - spawnConnectTask best - threadDelay 100000 - -- Try to connect to all other resolved addresses to prevent waiting for e.g. a long IPv6 connection timeout - forM_ others spawnConnectTask - -- Wait for all tasks to complete, throw an exception if all connections failed - connectTasks <- readMVar connectTasksMVar - results <- mapM (\(addr, task) -> (addr,) <$> waitCatch task) connectTasks - forM_ (collect results) (throwIO . ConnectionFailed . reverse) - collect :: [(Socket.AddrInfo, Either SomeException ())] -> Maybe [(Socket.AddrInfo, SomeException)] - collect ((_, Right ()):_) = Nothing - collect ((addr, Left ex):xs) = ((addr, ex):) <$> collect xs - collect [] = Just [] - connectTask :: Socket.AddrInfo -> IO (Async ()) - connectTask addr = async $ do - sock <- connect addr - isFirst <- tryPutMVar sockMVar sock - unless isFirst $ Socket.close sock - - -- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail - sock <- - (withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar)) - `finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar)) - `onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar) - -- As soon as we have an open connection, stop spawning more connections - newClient sock - where - hints :: Socket.AddrInfo - hints = Socket.defaultHints { Socket.addrFlags = [Socket.AI_ADDRCONFIG], Socket.addrSocketType = Socket.Stream } - connect :: Socket.AddrInfo -> IO Socket.Socket - connect addr = bracketOnError (openSocket addr) Socket.close $ \sock -> do - Socket.withFdSocket sock Socket.setCloseOnExecIfNeeded - Socket.connect sock $ Socket.addrAddress addr - pure sock +newClientTCP host port = newClient =<< connectTCP host port + withClientUnix :: RpcProtocol p => FilePath -> (Client p -> IO a) -> IO a withClientUnix socketPath = bracket (newClientUnix socketPath) clientClose @@ -454,6 +402,7 @@ listenOnBoundSocket protocolImpl sock = do runServerHandler :: forall p a. (RpcProtocol p, HasProtocolImpl p, IsSocketConnection a) => ProtocolImpl p -> a -> IO () runServerHandler protocolImpl = runMultiplexerProtocol (registerChannelServerHandler @p protocolImpl) . toSocketConnection + -- ** Test implementation withDummyClientServer :: forall p a. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> (Client p -> IO a) -> IO a @@ -463,25 +412,9 @@ withDummyClientServer impl runClientHook = do link serverTask withClient clientSocket runClientHook -newDummySocketPair :: IO (SocketConnection, SocketConnection) -newDummySocketPair = do - upstream <- newEmptyMVar - downstream <- newEmptyMVar - let x = SocketConnection { - send=putMVar upstream . BSL.toStrict, - receive=takeMVar downstream, - close=pure () - } - let y = SocketConnection { - send=putMVar downstream . BSL.toStrict, - receive=takeMVar upstream, - close=pure () - } - pure (x, y) - -- * Internal --- + -- ** Protocol generator helpers functionArgumentTypes :: RpcFunction -> Q [Type] @@ -579,8 +512,3 @@ withAsyncLinked inner outer = withAsync inner $ \task -> link task >> outer task withAsyncLinked_ :: IO a -> IO b -> IO b withAsyncLinked_ x = withAsyncLinked x . const - - --- | Reimplementation of 'openSocket' from the 'network'-package, which got introduced in version 3.1.2.0. Should be removed later. -openSocket :: Socket.AddrInfo -> IO Socket.Socket -openSocket addr = Socket.socket (Socket.addrFamily addr) (Socket.addrSocketType addr) (Socket.addrProtocol addr) diff --git a/src/Network/Rpc/Connection.hs b/src/Network/Rpc/Connection.hs new file mode 100644 index 0000000..226ac39 --- /dev/null +++ b/src/Network/Rpc/Connection.hs @@ -0,0 +1,112 @@ +module Network.Rpc.Connection where + +import Control.Concurrent (threadDelay) +import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync) +import Control.Concurrent.MVar +import Control.Exception (Exception(..), SomeException, bracketOnError, finally, throwIO, bracketOnError, onException) +import Control.Monad ((>=>), unless, forM_) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BSL +import Data.List (intercalate) +import GHC.IO (unsafeUnmask) +import qualified Network.Socket as Socket +import qualified Network.Socket.ByteString as Socket +import qualified Network.Socket.ByteString.Lazy as SocketL +import Prelude + +-- | Abstraction over a bidirectional stream connection (e.g. a socket), to be able to switch to different communication channels (e.g. stdin/stdout or a dummy implementation for unit tests). +data SocketConnection = SocketConnection { + send :: BSL.ByteString -> IO (), + receive :: IO BS.ByteString, + close :: IO () +} +class IsSocketConnection a where + toSocketConnection :: a -> SocketConnection +instance IsSocketConnection SocketConnection where + toSocketConnection = id +instance IsSocketConnection Socket.Socket where + toSocketConnection sock = SocketConnection { + send=SocketL.sendAll sock, + receive=Socket.recv sock 4096, + close=Socket.gracefulClose sock 2000 + } + + +newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)] + deriving (Show) +instance Exception ConnectionFailed where + displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts) + +connectTCP :: Socket.HostName -> Socket.ServiceName -> IO Socket.Socket +connectTCP host port = do + -- 'getAddrInfo' either pures a non-empty list or throws an exception + (best:others) <- Socket.getAddrInfo (Just hints) (Just host) (Just port) + + connectTasksMVar <- newMVar [] + sockMVar <- newEmptyMVar + let + spawnConnectTask :: Socket.AddrInfo -> IO () + spawnConnectTask = \addr -> modifyMVar_ connectTasksMVar $ \old -> (:old) . (addr,) <$> connectTask addr + -- Race more connections (a missed TCP SYN will result in 3s wait before a retransmission; IPv6 might be broken) + -- Inspired by a similar implementation in browsers + raceConnections :: IO () + raceConnections = do + spawnConnectTask best + threadDelay 200000 + -- Give the "best" address another try, in case the TCP SYN gets dropped (kernel retry interval can be multiple seconds long) + spawnConnectTask best + threadDelay 100000 + -- Try to connect to all other resolved addresses to prevent waiting for e.g. a long IPv6 connection timeout + forM_ others spawnConnectTask + -- Wait for all tasks to complete, throw an exception if all connections failed + connectTasks <- readMVar connectTasksMVar + results <- mapM (\(addr, task) -> (addr,) <$> waitCatch task) connectTasks + forM_ (collect results) (throwIO . ConnectionFailed . reverse) + collect :: [(Socket.AddrInfo, Either SomeException ())] -> Maybe [(Socket.AddrInfo, SomeException)] + collect ((_, Right ()):_) = Nothing + collect ((addr, Left ex):xs) = ((addr, ex):) <$> collect xs + collect [] = Just [] + connectTask :: Socket.AddrInfo -> IO (Async ()) + connectTask addr = async $ do + sock <- connect addr + isFirst <- tryPutMVar sockMVar sock + unless isFirst $ Socket.close sock + + -- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail + sock <- + (withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar)) + `finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar)) + `onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar) + -- As soon as we have an open connection, stop spawning more connections + pure sock + where + hints :: Socket.AddrInfo + hints = Socket.defaultHints { Socket.addrFlags = [Socket.AI_ADDRCONFIG], Socket.addrSocketType = Socket.Stream } + connect :: Socket.AddrInfo -> IO Socket.Socket + connect addr = bracketOnError (openSocket addr) Socket.close $ \sock -> do + Socket.withFdSocket sock Socket.setCloseOnExecIfNeeded + Socket.connect sock $ Socket.addrAddress addr + pure sock + + +newDummySocketPair :: IO (SocketConnection, SocketConnection) +newDummySocketPair = do + upstream <- newEmptyMVar + downstream <- newEmptyMVar + let x = SocketConnection { + send=putMVar upstream . BSL.toStrict, + receive=takeMVar downstream, + close=pure () + } + let y = SocketConnection { + send=putMVar downstream . BSL.toStrict, + receive=takeMVar upstream, + close=pure () + } + pure (x, y) + + + +-- | Reimplementation of 'openSocket' from the 'network'-package, which got introduced in version 3.1.2.0. Should be removed later. +openSocket :: Socket.AddrInfo -> IO Socket.Socket +openSocket addr = Socket.socket (Socket.addrFamily addr) (Socket.addrSocketType addr) (Socket.addrProtocol addr) diff --git a/src/Network/Rpc/Multiplexer.hs b/src/Network/Rpc/Multiplexer.hs index 88292c1..957e84b 100644 --- a/src/Network/Rpc/Multiplexer.hs +++ b/src/Network/Rpc/Multiplexer.hs @@ -1,6 +1,4 @@ module Network.Rpc.Multiplexer ( - SocketConnection(..), - IsSocketConnection(..), ChannelId, MessageId, MessageLength, @@ -34,30 +32,11 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BSL import qualified Data.HashMap.Strict as HM import Data.Word -import qualified Network.Socket as Socket -import qualified Network.Socket.ByteString as Socket -import qualified Network.Socket.ByteString.Lazy as SocketL +import Network.Rpc.Connection import Prelude import GHC.Generics import System.IO (hPutStrLn, stderr) --- | Abstraction over a socket connection, to be able to switch to different communication channels (e.g. the dummy implementation for unit tests). -data SocketConnection = SocketConnection { - send :: BSL.ByteString -> IO (), - receive :: IO BS.ByteString, - close :: IO () -} -class IsSocketConnection a where - toSocketConnection :: a -> SocketConnection -instance IsSocketConnection SocketConnection where - toSocketConnection = id -instance IsSocketConnection Socket.Socket where - toSocketConnection sock = SocketConnection { - send=SocketL.sendAll sock, - receive=Socket.recv sock 4096, - close=Socket.gracefulClose sock 2000 - } - type ChannelId = Word64 type MessageId = Word64 type MessageLength = Word64 -- GitLab