Skip to content
Snippets Groups Projects
Commit 1eaa018d authored by Jens Nolte's avatar Jens Nolte
Browse files

Move socket abstraction into a dedicated module

parent 86506e9a
No related branches found
No related tags found
No related merge requests found
......@@ -80,6 +80,7 @@ library
import: shared-properties
exposed-modules:
Network.Rpc
Network.Rpc.Connection
Network.Rpc.Multiplexer
hs-source-dirs:
src
......
module Network.Rpc where
import Control.Concurrent (threadDelay, forkFinally)
import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync)
import Control.Exception (Exception(..), SomeException, bracket, bracketOnError, finally, throwIO, bracketOnError, onException)
import Control.Monad ((>=>), when, unless, forever, forM_)
import Control.Concurrent (forkFinally)
import Control.Concurrent.Async (Async, async, link, withAsync)
import Control.Exception (SomeException, bracket, bracketOnError, bracketOnError)
import Control.Monad ((>=>), when, forever)
import Control.Monad.State (State, execState)
import qualified Control.Monad.State as State
import Control.Concurrent.MVar
......@@ -11,15 +11,15 @@ import Data.Binary (Binary, encode, decodeOrFail)
import qualified Data.ByteString.Lazy as BSL
import Data.Hashable (Hashable)
import qualified Data.HashMap.Strict as HM
import Data.List (intercalate)
import Data.Maybe (isNothing)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Network.Rpc.Multiplexer
import Network.Rpc.Connection
import qualified Network.Socket as Socket
import Prelude
import GHC.IO (unsafeUnmask)
import GHC.Generics
import GHC.IO (unsafeUnmask)
import System.Posix.Files (getFileStatus, isSocket)
......@@ -316,64 +316,12 @@ registerChannelServerHandler protocolImpl channel = channelSetHandler channel (s
-- ** Running client and server
newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)]
deriving (Show)
instance Exception ConnectionFailed where
displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts)
withClientTCP :: RpcProtocol p => Socket.HostName -> Socket.ServiceName -> (Client p -> IO a) -> IO a
withClientTCP host port = bracket (newClientTCP host port) clientClose
newClientTCP :: forall p. RpcProtocol p => Socket.HostName -> Socket.ServiceName -> IO (Client p)
newClientTCP host port = do
-- 'getAddrInfo' either pures a non-empty list or throws an exception
(best:others) <- Socket.getAddrInfo (Just hints) (Just host) (Just port)
connectTasksMVar <- newMVar []
sockMVar <- newEmptyMVar
let
spawnConnectTask :: Socket.AddrInfo -> IO ()
spawnConnectTask = \addr -> modifyMVar_ connectTasksMVar $ \old -> (:old) . (addr,) <$> connectTask addr
-- Race more connections (a missed TCP SYN will result in 3s wait before a retransmission; IPv6 might be broken)
-- Inspired by a similar implementation in browsers
raceConnections :: IO ()
raceConnections = do
spawnConnectTask best
threadDelay 200000
-- Give the "best" address another try, in case the TCP SYN gets dropped (kernel retry interval can be multiple seconds long)
spawnConnectTask best
threadDelay 100000
-- Try to connect to all other resolved addresses to prevent waiting for e.g. a long IPv6 connection timeout
forM_ others spawnConnectTask
-- Wait for all tasks to complete, throw an exception if all connections failed
connectTasks <- readMVar connectTasksMVar
results <- mapM (\(addr, task) -> (addr,) <$> waitCatch task) connectTasks
forM_ (collect results) (throwIO . ConnectionFailed . reverse)
collect :: [(Socket.AddrInfo, Either SomeException ())] -> Maybe [(Socket.AddrInfo, SomeException)]
collect ((_, Right ()):_) = Nothing
collect ((addr, Left ex):xs) = ((addr, ex):) <$> collect xs
collect [] = Just []
connectTask :: Socket.AddrInfo -> IO (Async ())
connectTask addr = async $ do
sock <- connect addr
isFirst <- tryPutMVar sockMVar sock
unless isFirst $ Socket.close sock
-- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail
sock <-
(withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar))
`finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar))
`onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar)
-- As soon as we have an open connection, stop spawning more connections
newClient sock
where
hints :: Socket.AddrInfo
hints = Socket.defaultHints { Socket.addrFlags = [Socket.AI_ADDRCONFIG], Socket.addrSocketType = Socket.Stream }
connect :: Socket.AddrInfo -> IO Socket.Socket
connect addr = bracketOnError (openSocket addr) Socket.close $ \sock -> do
Socket.withFdSocket sock Socket.setCloseOnExecIfNeeded
Socket.connect sock $ Socket.addrAddress addr
pure sock
newClientTCP host port = newClient =<< connectTCP host port
withClientUnix :: RpcProtocol p => FilePath -> (Client p -> IO a) -> IO a
withClientUnix socketPath = bracket (newClientUnix socketPath) clientClose
......@@ -454,6 +402,7 @@ listenOnBoundSocket protocolImpl sock = do
runServerHandler :: forall p a. (RpcProtocol p, HasProtocolImpl p, IsSocketConnection a) => ProtocolImpl p -> a -> IO ()
runServerHandler protocolImpl = runMultiplexerProtocol (registerChannelServerHandler @p protocolImpl) . toSocketConnection
-- ** Test implementation
withDummyClientServer :: forall p a. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> (Client p -> IO a) -> IO a
......@@ -463,25 +412,9 @@ withDummyClientServer impl runClientHook = do
link serverTask
withClient clientSocket runClientHook
newDummySocketPair :: IO (SocketConnection, SocketConnection)
newDummySocketPair = do
upstream <- newEmptyMVar
downstream <- newEmptyMVar
let x = SocketConnection {
send=putMVar upstream . BSL.toStrict,
receive=takeMVar downstream,
close=pure ()
}
let y = SocketConnection {
send=putMVar downstream . BSL.toStrict,
receive=takeMVar upstream,
close=pure ()
}
pure (x, y)
-- * Internal
--
-- ** Protocol generator helpers
functionArgumentTypes :: RpcFunction -> Q [Type]
......@@ -579,8 +512,3 @@ withAsyncLinked inner outer = withAsync inner $ \task -> link task >> outer task
withAsyncLinked_ :: IO a -> IO b -> IO b
withAsyncLinked_ x = withAsyncLinked x . const
-- | Reimplementation of 'openSocket' from the 'network'-package, which got introduced in version 3.1.2.0. Should be removed later.
openSocket :: Socket.AddrInfo -> IO Socket.Socket
openSocket addr = Socket.socket (Socket.addrFamily addr) (Socket.addrSocketType addr) (Socket.addrProtocol addr)
module Network.Rpc.Connection where
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (Async, async, cancel, link, waitCatch, withAsync)
import Control.Concurrent.MVar
import Control.Exception (Exception(..), SomeException, bracketOnError, finally, throwIO, bracketOnError, onException)
import Control.Monad ((>=>), unless, forM_)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.List (intercalate)
import GHC.IO (unsafeUnmask)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as Socket
import qualified Network.Socket.ByteString.Lazy as SocketL
import Prelude
-- | Abstraction over a bidirectional stream connection (e.g. a socket), to be able to switch to different communication channels (e.g. stdin/stdout or a dummy implementation for unit tests).
data SocketConnection = SocketConnection {
send :: BSL.ByteString -> IO (),
receive :: IO BS.ByteString,
close :: IO ()
}
class IsSocketConnection a where
toSocketConnection :: a -> SocketConnection
instance IsSocketConnection SocketConnection where
toSocketConnection = id
instance IsSocketConnection Socket.Socket where
toSocketConnection sock = SocketConnection {
send=SocketL.sendAll sock,
receive=Socket.recv sock 4096,
close=Socket.gracefulClose sock 2000
}
newtype ConnectionFailed = ConnectionFailed [(Socket.AddrInfo, SomeException)]
deriving (Show)
instance Exception ConnectionFailed where
displayException (ConnectionFailed attemts) = "Connection attempts failed:\n" <> intercalate "\n" (map (\(addr, err) -> show (Socket.addrAddress addr) <> ": " <> displayException err) attemts)
connectTCP :: Socket.HostName -> Socket.ServiceName -> IO Socket.Socket
connectTCP host port = do
-- 'getAddrInfo' either pures a non-empty list or throws an exception
(best:others) <- Socket.getAddrInfo (Just hints) (Just host) (Just port)
connectTasksMVar <- newMVar []
sockMVar <- newEmptyMVar
let
spawnConnectTask :: Socket.AddrInfo -> IO ()
spawnConnectTask = \addr -> modifyMVar_ connectTasksMVar $ \old -> (:old) . (addr,) <$> connectTask addr
-- Race more connections (a missed TCP SYN will result in 3s wait before a retransmission; IPv6 might be broken)
-- Inspired by a similar implementation in browsers
raceConnections :: IO ()
raceConnections = do
spawnConnectTask best
threadDelay 200000
-- Give the "best" address another try, in case the TCP SYN gets dropped (kernel retry interval can be multiple seconds long)
spawnConnectTask best
threadDelay 100000
-- Try to connect to all other resolved addresses to prevent waiting for e.g. a long IPv6 connection timeout
forM_ others spawnConnectTask
-- Wait for all tasks to complete, throw an exception if all connections failed
connectTasks <- readMVar connectTasksMVar
results <- mapM (\(addr, task) -> (addr,) <$> waitCatch task) connectTasks
forM_ (collect results) (throwIO . ConnectionFailed . reverse)
collect :: [(Socket.AddrInfo, Either SomeException ())] -> Maybe [(Socket.AddrInfo, SomeException)]
collect ((_, Right ()):_) = Nothing
collect ((addr, Left ex):xs) = ((addr, ex):) <$> collect xs
collect [] = Just []
connectTask :: Socket.AddrInfo -> IO (Async ())
connectTask addr = async $ do
sock <- connect addr
isFirst <- tryPutMVar sockMVar sock
unless isFirst $ Socket.close sock
-- The 'raceConnections'-async is 'link'ed to this thread, so 'readMVar' is interrupted when all connection attempts fail
sock <-
(withAsync (unsafeUnmask raceConnections) (link >=> const (readMVar sockMVar))
`finally` (mapM_ (cancel . snd) =<< readMVar connectTasksMVar))
`onException` (mapM_ Socket.close =<< tryTakeMVar sockMVar)
-- As soon as we have an open connection, stop spawning more connections
pure sock
where
hints :: Socket.AddrInfo
hints = Socket.defaultHints { Socket.addrFlags = [Socket.AI_ADDRCONFIG], Socket.addrSocketType = Socket.Stream }
connect :: Socket.AddrInfo -> IO Socket.Socket
connect addr = bracketOnError (openSocket addr) Socket.close $ \sock -> do
Socket.withFdSocket sock Socket.setCloseOnExecIfNeeded
Socket.connect sock $ Socket.addrAddress addr
pure sock
newDummySocketPair :: IO (SocketConnection, SocketConnection)
newDummySocketPair = do
upstream <- newEmptyMVar
downstream <- newEmptyMVar
let x = SocketConnection {
send=putMVar upstream . BSL.toStrict,
receive=takeMVar downstream,
close=pure ()
}
let y = SocketConnection {
send=putMVar downstream . BSL.toStrict,
receive=takeMVar upstream,
close=pure ()
}
pure (x, y)
-- | Reimplementation of 'openSocket' from the 'network'-package, which got introduced in version 3.1.2.0. Should be removed later.
openSocket :: Socket.AddrInfo -> IO Socket.Socket
openSocket addr = Socket.socket (Socket.addrFamily addr) (Socket.addrSocketType addr) (Socket.addrProtocol addr)
module Network.Rpc.Multiplexer (
SocketConnection(..),
IsSocketConnection(..),
ChannelId,
MessageId,
MessageLength,
......@@ -34,30 +32,11 @@ import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as HM
import Data.Word
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as Socket
import qualified Network.Socket.ByteString.Lazy as SocketL
import Network.Rpc.Connection
import Prelude
import GHC.Generics
import System.IO (hPutStrLn, stderr)
-- | Abstraction over a socket connection, to be able to switch to different communication channels (e.g. the dummy implementation for unit tests).
data SocketConnection = SocketConnection {
send :: BSL.ByteString -> IO (),
receive :: IO BS.ByteString,
close :: IO ()
}
class IsSocketConnection a where
toSocketConnection :: a -> SocketConnection
instance IsSocketConnection SocketConnection where
toSocketConnection = id
instance IsSocketConnection Socket.Socket where
toSocketConnection sock = SocketConnection {
send=SocketL.sendAll sock,
receive=Socket.recv sock 4096,
close=Socket.gracefulClose sock 2000
}
type ChannelId = Word64
type MessageId = Word64
type MessageLength = Word64
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment