diff --git a/quasar-wayland.cabal b/quasar-wayland.cabal index 82a0113e133aea2bbed7ca2748c55f825a125a85..b767d0c80cc18410be01d5d07949dc114f4f212f 100644 --- a/quasar-wayland.cabal +++ b/quasar-wayland.cabal @@ -86,6 +86,7 @@ library Quasar.Wayland.Protocol.Core Quasar.Wayland.Utils.InlineC Quasar.Wayland.Utils.SharedMemory + Quasar.Wayland.Utils.Socket build-depends: base >=4.7 && <5, binary, diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs index afa515f969577410974ec37719d2680ee924efdd..a4285c907d7f7db510a99b8e3b255e7ca2368e5d 100644 --- a/src/Quasar/Wayland/Connection.hs +++ b/src/Quasar/Wayland/Connection.hs @@ -1,3 +1,6 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} + module Quasar.Wayland.Connection ( WaylandConnection, newWaylandConnection, @@ -5,8 +8,13 @@ module Quasar.Wayland.Connection ( import Control.Concurrent.STM import Control.Monad.Catch +import Data.Bits ((.&.)) import Data.ByteString qualified as BS +import Data.ByteString.Internal (createUptoN) import Data.ByteString.Lazy qualified as BSL +import Foreign.Storable (sizeOf) +import Language.C.Inline qualified as C +import Language.C.Inline.Unsafe qualified as CU import Network.Socket (Socket) import Network.Socket qualified as Socket import Network.Socket.ByteString qualified as Socket @@ -14,6 +22,18 @@ import Network.Socket.ByteString.Lazy qualified as SocketL import Quasar import Quasar.Prelude import Quasar.Wayland.Protocol +import Quasar.Wayland.Utils.Socket +import System.Posix.Types (Fd) + + +C.include "<sys/socket.h>" + +maxFds :: C.CInt +maxFds = 28 -- from wayland (connection.c) + +cmsgBufferSize :: Int +cmsgBufferSize = fromIntegral [CU.pure|int { CMSG_LEN($(int maxFds) * sizeof(int32_t)) }|] + data WaylandConnection s = WaylandConnection { @@ -67,22 +87,54 @@ connectionThread connection work = asyncWithHandler traceAndDisposeConnection $ sendThread :: WaylandConnection s -> IO () sendThread connection = forever do - bytes <- takeOutbox connection.protocolHandle + (msg, fds) <- takeOutbox connection.protocolHandle + + let msgLength = fromIntegral (BSL.length msg) - traceIO $ "Sending " <> show (BSL.length bytes) <> " bytes" - SocketL.sendAll connection.socket bytes + traceIO $ "Sending " <> show msgLength <> " bytes" + + -- TODO limit max fds + send msgLength (BSL.toChunks msg) fds + + where + send :: Int -> [BS.ByteString] -> [Fd] -> IO () + send remaining chunks fds = do + -- TODO add MSG_NOSIGNAL (not exposed by `network`) + sent <- sendMsg connection.socket chunks (Socket.encodeCmsg <$> fds) mempty + let nowRemaining = remaining - sent + when (nowRemaining > 0) do + send nowRemaining (drop sent chunks) [] + + drop :: Int -> [BS.ByteString] -> [BS.ByteString] + drop _ [] = [] + drop amount (chunk:chunks) = + if (amount < BS.length chunk) + then (BS.drop amount chunk : chunks) + else drop (amount - BS.length chunk) chunks receiveThread :: IsSide s => WaylandConnection s -> IO () receiveThread connection = forever do - bytes <- Socket.recv connection.socket 4096 + -- TODO add MSG_CMSG_CLOEXEC (not exposed by `network`) + (chunk, cmsgs, flags) <- recvMsg connection.socket 4096 cmsgBufferSize mempty + + let fds = catMaybes (Socket.decodeCmsg @Fd <$> cmsgs) - when (BS.null bytes) do + when (flags .&. Socket.MSG_CTRUNC > 0) do + -- TODO close fds + fail "Wayland connection: Ancillary data was truncated" + + when (length fds /= length cmsgs) do + -- TODO close fds + fail "Wayland connection: Received unexpected ancillary message (only SCM_RIGHTS is supported)" + + when (BS.null chunk) do throwM SocketClosed - traceIO $ "Received " <> show (BS.length bytes) <> " bytes" + traceIO $ "Received " <> show (BS.length chunk) <> " bytes" + + feedInput connection.protocolHandle chunk fds - feedInput connection.protocolHandle bytes closeConnection :: WaylandConnection s -> IO () closeConnection connection = Socket.close connection.socket diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index 60e18081717223161df4a0575acfc1aaddee9690..0445a6f52a88ae6850c2341d020d88d697d337f3 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -47,6 +47,7 @@ module Quasar.Wayland.Protocol.Core ( bindNewObject, getObject, lookupObject, + buildMessage, -- * wl_display interface handleWlDisplayError, @@ -76,15 +77,19 @@ import Data.ByteString (ByteString) import Data.ByteString qualified as BS import Data.ByteString.Lazy qualified as BSL import Data.ByteString.UTF8 qualified as BSUTF8 +import Data.Foldable (toList) import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as HM import Data.Proxy +import Data.Sequence (Seq) +import Data.Sequence qualified as Seq import Data.String (IsString(..)) import Data.Typeable (Typeable, cast) import Data.Void (absurd) import GHC.Conc (unsafeIOToSTM) import GHC.TypeLits import Quasar.Prelude +import System.Posix.Types (Fd(Fd)) newtype ObjectId (j :: Symbol) = ObjectId Word32 @@ -132,24 +137,32 @@ toString :: WlString -> String toString (WlString bs) = BSUTF8.toString bs +data MessagePart = MessagePart Put Int (Seq Fd) + +instance Semigroup MessagePart where + (MessagePart px lx fx) <> (MessagePart py ly fy) = MessagePart (px <> py) (lx + ly) (fx <> fy) + +instance Monoid MessagePart where + mempty = MessagePart mempty 0 mempty + class (Eq a, Show a) => WireFormat a where - putArgument :: a -> ProtocolM s (Put, Int) + putArgument :: a -> Either SomeException MessagePart getArgument :: Get (ProtocolM s a) showArgument :: a -> String instance WireFormat Int32 where - putArgument x = pure (putInt32host x, 4) + putArgument x = pure $ MessagePart (putInt32host x) 4 mempty getArgument = pure <$> getInt32host showArgument = show instance WireFormat Word32 where - putArgument x = pure (putWord32host x, 4) + putArgument x = pure $ MessagePart (putWord32host x) 4 mempty getArgument = pure <$> getWord32host showArgument = show instance WireFormat Fixed where - putArgument (Fixed repr) = pure (putWord32host repr, 4) + putArgument (Fixed repr) = pure $ MessagePart (putWord32host repr) 4 mempty getArgument = pure . Fixed <$> getWord32host showArgument = show @@ -164,12 +177,12 @@ instance WireFormat BS.ByteString where showArgument array = "[array " <> show (BS.length array) <> "B]" instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where - putArgument (ObjectId oId) = pure (putWord32host oId, 4) + putArgument (ObjectId oId) = pure $ MessagePart (putWord32host oId) 4 mempty getArgument = pure . ObjectId <$> getWord32host showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId instance WireFormat GenericObjectId where - putArgument (GenericObjectId oId) = pure (putWord32host oId, 4) + putArgument (GenericObjectId oId) = pure $ MessagePart (putWord32host oId) 4 mempty getArgument = pure . GenericObjectId <$> getWord32host showArgument oId = "[unknown]@" <> show oId @@ -180,17 +193,17 @@ instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where instance WireFormat GenericNewId where putArgument (GenericNewId interface version newId) = do - (put1, s1) <- putArgument interface - (put2, s2) <- putArgument version - (put3, s3) <- putArgument newId - pure (put1 >> put2 >> put3, s1 + s2 + s3) + p1 <- putArgument interface + p2 <- putArgument version + p3 <- putArgument newId + pure (p1 <> p2 <> p3) getArgument = GenericNewId <<$>> getArgument <<*>> getArgument <<*>> getArgument showArgument (GenericNewId interface version newId) = mconcat ["new ", toString interface, "[v", show version, "]@", show newId] -instance WireFormat Void where - putArgument = absurd - getArgument = pure <$> get - showArgument = absurd +instance WireFormat Fd where + putArgument fd = pure (MessagePart mempty 0 (Seq.singleton fd)) + getArgument = undefined + showArgument (Fd fd) = "fd@" <> show fd -- | Class for a proxy type (in the haskell sense) that describes a Wayland interface. @@ -251,7 +264,7 @@ class ( getWireDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (ProtocolM s (WireDown s i)) getWireDown = getMessage @(WireDown s i) -putWireUp :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> ProtocolM s (Opcode, [(Put, Int)]) +putWireUp :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> Either SomeException (Opcode, MessagePart) putWireUp _ = putMessage @(WireUp s i) @@ -329,13 +342,16 @@ instance IsObjectSide (SomeObject s) where class (Eq a, Show a) => IsMessage a where opcodeName :: Opcode -> Maybe String getMessage :: IsInterface i => Object s i -> Opcode -> Get (ProtocolM s a) - putMessage :: a -> ProtocolM s (Opcode, [(Put, Int)]) + putMessage :: a -> Either SomeException (Opcode, MessagePart) instance IsMessage Void where opcodeName _ = Nothing getMessage = invalidOpcode putMessage = absurd +buildMessage :: Opcode -> [Either SomeException MessagePart] -> Either SomeException (Opcode, MessagePart) +buildMessage opcode parts = (opcode,) . mconcat <$> sequence parts + invalidOpcode :: IsInterface i => Object s i -> Opcode -> Get a invalidOpcode object opcode = fail $ mconcat [ "Invalid opcode ", show opcode, " on ", toString (objectInterfaceName object), @@ -390,7 +406,9 @@ data ProtocolState (s :: Side) = ProtocolState { bytesReceivedVar :: TVar Int64, bytesSentVar :: TVar Int64, inboxDecoderVar :: TVar (Decoder RawMessage), + inboxFdsVar :: TVar (Seq Fd), outboxVar :: TVar (Maybe Put), + outboxFdsVar :: TVar (Seq Fd), objectsVar :: TVar (HashMap GenericObjectId (SomeObject s)), nextIdVar :: TVar Word32 } @@ -425,6 +443,11 @@ stateProtocolVar fn x = do state <- ask lift $ stateTVar (fn state) x +swapProtocolVar :: (ProtocolState s -> TVar a) -> a -> ProtocolM s a +swapProtocolVar fn x = do + state <- ask + lift $ swapTVar (fn state) x + initializeProtocol :: forall s wl_display a. (IsInterfaceSide s wl_display) => (ProtocolHandle s -> MessageHandler s wl_display) @@ -434,7 +457,9 @@ initializeProtocol wlDisplayMessageHandler initializationAction = do bytesReceivedVar <- newTVar 0 bytesSentVar <- newTVar 0 inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage + inboxFdsVar <- newTVar mempty outboxVar <- newTVar Nothing + outboxFdsVar <- newTVar mempty protocolKey <- unsafeIOToSTM newUnique objectsVar <- newTVar $ HM.empty nextIdVar <- newTVar (initialId @s) @@ -452,7 +477,9 @@ initializeProtocol wlDisplayMessageHandler initializationAction = do bytesReceivedVar, bytesSentVar, inboxDecoderVar, + inboxFdsVar, outboxVar, + outboxFdsVar, objectsVar, nextIdVar } @@ -501,12 +528,13 @@ runProtocolM protocol action = either throwM (runReaderT action) =<< readTVar pr -- | Feed the protocol newly received data. -feedInput :: (IsSide s, MonadIO m) => ProtocolHandle s -> ByteString -> m () -feedInput protocol bytes = runProtocolTransaction protocol do +feedInput :: (IsSide s, MonadIO m) => ProtocolHandle s -> ByteString -> [Fd] -> m () +feedInput protocol bytes fds = runProtocolTransaction protocol do -- Exposing MonadIO instead of STM to the outside and using `runProtocolTransaction` here enforces correct exception -- handling. modifyProtocolVar' (.bytesReceivedVar) (+ fromIntegral (BS.length bytes)) modifyProtocolVar (.inboxDecoderVar) (`pushChunk` bytes) + modifyProtocolVar (.inboxFdsVar) (<> Seq.fromList fds) receiveMessages -- | Set the protocol to a failed state, e.g. when the socket closed unexpectedly. @@ -514,13 +542,14 @@ setException :: (Exception e, MonadIO m) => ProtocolHandle s -> e -> m () setException protocol ex = runProtocolTransaction protocol $ throwM ex -- | Take data that has to be sent. Blocks until data is available. -takeOutbox :: MonadIO m => ProtocolHandle s -> m (BSL.ByteString) +takeOutbox :: MonadIO m => ProtocolHandle s -> m (BSL.ByteString, [Fd]) takeOutbox protocol = runProtocolTransaction protocol do mOutboxData <- stateProtocolVar (.outboxVar) (\mOutboxData -> (mOutboxData, Nothing)) outboxData <- maybe (lift retry) pure mOutboxData let sendData = runPut outboxData modifyProtocolVar' (.bytesSentVar) (+ BSL.length sendData) - pure sendData + fds <- swapProtocolVar (.outboxFdsVar) mempty + pure (sendData, toList fds) -- | Create an object. The caller is responsible for sending the 'NewId' immediately (exactly once and before using the @@ -654,16 +683,13 @@ sendMessage object message = do Left msg -> throwM $ ProtocolUsageError $ "Tried to send message to an invalid object: " <> msg Right () -> pure () - (opcode, pairs) <- putWireUp object message - let (putBodyParts, partLengths) = unzip pairs - let putBody = mconcat putBodyParts + (opcode, MessagePart putBody bodyLength fds) <- either throwM pure $ putWireUp object message - let bodyLength = foldr (+) 8 partLengths when (bodyLength > fromIntegral (maxBound :: Word16)) $ throwM $ ProtocolUsageError $ "Tried to send message larger than 2^16 bytes" traceM $ "-> " <> showObjectMessage object message - sendRawMessage $ putHeader opcode bodyLength >> putBody + sendRawMessage (putHeader opcode (8 + bodyLength) >> putBody) fds where oId = genericObjectId object (GenericObjectId objectIdWord) = genericObjectId object @@ -741,12 +767,12 @@ getWaylandBlob = do skipPadding pure string -putWaylandBlob :: BS.ByteString -> ProtocolM s (Put, Int) +putWaylandBlob :: MonadThrow m => BS.ByteString -> m MessagePart putWaylandBlob blob = do when (len > fromIntegral (maxBound :: Word16)) $ throwM $ ProtocolUsageError $ "Tried to send string or array larger than 2^16 bytes" - pure (putBlob, 4 + len + pad) + pure $ MessagePart putBlob (4 + len + pad) mempty where -- Total data length including null byte len = BS.length blob + 1 @@ -768,5 +794,7 @@ padding :: Integral a => a -> a padding size = ((4 - (size `mod` 4)) `mod` 4) -sendRawMessage :: Put -> ProtocolM s () -sendRawMessage x = modifyProtocolVar (.outboxVar) (Just . maybe x (<> x)) +sendRawMessage :: Put -> Seq Fd -> ProtocolM s () +sendRawMessage x fds = do + modifyProtocolVar (.outboxVar) (Just . maybe x (<> x)) + modifyProtocolVar (.outboxFdsVar) (<> fds) diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs index c0690241f72ab919eb00c86c9fe3d045deda6738..4cb304a426fcf2bb551b4079113b72cc3920e7b1 100644 --- a/src/Quasar/Wayland/Protocol/TH.hs +++ b/src/Quasar/Wayland/Protocol/TH.hs @@ -12,6 +12,7 @@ import Language.Haskell.TH import Language.Haskell.TH.Syntax (addDependentFile) import Quasar.Prelude import Quasar.Wayland.Protocol.Core +import System.Posix.Types (Fd(Fd)) import Text.Read (readEither) import Text.XML.Light @@ -437,10 +438,9 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, putMessageClauseD msg = clause [msgConP msg] (normalB (putMessageE msg.msgSpec.arguments)) [] where putMessageE :: [ArgumentSpec] -> Q Exp - putMessageE args = [|($(litE $ integerL $ fromIntegral msg.msgSpec.opcode), ) <$> $(putMessageBodyE args)|] + putMessageE args = [|buildMessage $(litE $ integerL $ fromIntegral msg.msgSpec.opcode) $(putMessageBodyE args)|] putMessageBodyE :: [ArgumentSpec] -> Q Exp - putMessageBodyE [] = [|pure []|] - putMessageBodyE args = [|sequence $(listE ((\arg -> [|putArgument @($(argumentWireType arg)) $(msgArgE msg arg)|]) <$> args))|] + putMessageBodyE args = [|$(listE ((\arg -> [|putArgument @($(argumentWireType arg)) $(msgArgE msg arg)|]) <$> args))|] derivingEq :: Q DerivClause @@ -470,7 +470,7 @@ liftArgumentWireType (ObjectArgument iName) = [t|ObjectId $(litT (strTyLit iName liftArgumentWireType GenericObjectArgument = [t|GenericObjectId|] liftArgumentWireType (NewIdArgument iName) = [t|NewId $(litT (strTyLit iName))|] liftArgumentWireType GenericNewIdArgument = [t|GenericNewId|] -liftArgumentWireType FdArgument = [t|Void|] -- TODO +liftArgumentWireType FdArgument = [t|Fd|] -- * Generic TH utilities diff --git a/src/Quasar/Wayland/Utils/Socket.hs b/src/Quasar/Wayland/Utils/Socket.hs new file mode 100644 index 0000000000000000000000000000000000000000..c74a955a550c12df4f0779dea787db71762cbe16 --- /dev/null +++ b/src/Quasar/Wayland/Utils/Socket.hs @@ -0,0 +1,56 @@ +module Quasar.Wayland.Utils.Socket ( + recvMsg, + sendMsg, +) where + +import Data.ByteString qualified as BS +import Data.ByteString.Internal (ByteString(PS), create) +import Foreign +import Network.Socket +import Network.Socket.Address qualified as SA +import Network.Socket.Internal (zeroMemory) +import Quasar.Prelude + + +instance SA.SocketAddress () where + sizeOfSocketAddress _ = 0 + peekSocketAddress _ptr = pure () + pokeSocketAddress _prt () = pure () + + + +withBufSizs :: [ByteString] -> ([(Ptr Word8, Int)] -> IO a) -> IO a +withBufSizs bss0 f = loop bss0 id + where + loop [] !build = f $ build [] + loop (PS fptr off len:bss) !build = withForeignPtr fptr $ \ptr -> do + let !ptr' = ptr `plusPtr` off + loop bss (build . ((ptr',len) :)) + +-- | Send data to the connected socket using sendmsg(2). +sendMsg :: Socket -- ^ Socket + -> [BS.ByteString] -- ^ Data to be sent + -> [Cmsg] -- ^ Control messages + -> MsgFlag -- ^ Message flags + -> IO Int -- ^ The length actually sent +sendMsg _ [] _ _ = pure 0 +sendMsg s bss cmsgs flags = withBufSizs bss $ \bufsizs -> + SA.sendBufMsg s () bufsizs cmsgs flags + +-- | Receive data from the connected socket using recvmsg(2). +recvMsg :: Socket -- ^ Socket + -> Int -- ^ The maximum length of data to be received + -- If the total length is not large enough, + -- 'MSG_TRUNC' is returned + -> Int -- ^ The buffer size for control messages. + -- If the length is not large enough, + -- 'MSG_CTRUNC' is returned + -> MsgFlag -- ^ Message flags + -> IO (BS.ByteString, [Cmsg], MsgFlag) -- ^ Source address, received data, control messages and message flags +recvMsg s siz clen flags = do + bs@(PS fptr _ _) <- create siz $ \ptr -> zeroMemory ptr (fromIntegral siz) + withForeignPtr fptr $ \ptr -> do + ((),len,cmsgs,flags') <- SA.recvBufMsg s [(ptr,siz)] clen flags + let bs' | len < siz = PS fptr 0 len + | otherwise = bs + pure (bs', cmsgs, flags')