diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index abd711c6d18f9b17542c889418e5d97754614996..cb0dd2545ee76764e24e7774d2e75d8fa173d3f1 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -59,6 +59,7 @@ import Control.Concurrent.STM import Control.Monad (replicateM_) import Control.Monad.Catch import Control.Monad.Reader (ReaderT, runReaderT, ask, lift) +import Data.Bifunctor qualified as Bifunctor import Data.Binary import Data.Binary.Get import Data.Binary.Put @@ -104,6 +105,7 @@ instance Show Fixed where -- -- Instances and functions in this library assume UTF-8, but the original data is also available by deconstructing. newtype WlString = WlString BS.ByteString + deriving newtype (Eq, Hashable) instance Show WlString where show = show . toString @@ -139,62 +141,62 @@ isNewId _ = False class (Eq (Argument a), Show (Argument a)) => WireFormat a where type Argument a - putArgument :: Argument a -> PutM () - getArgument :: Get (Argument a) + putArgument :: Argument a -> ProtocolM s (Put, Int) + getArgument :: Get (ProtocolM s (Argument a)) showArgument :: Argument a -> String instance WireFormat 'IntArgument where type Argument 'IntArgument = Int32 - putArgument = putInt32host - getArgument = getInt32host + putArgument x = pure (putInt32host x, 4) + getArgument = pure <$> getInt32host showArgument = show instance WireFormat 'UIntArgument where type Argument 'UIntArgument = Word32 - putArgument = putWord32host - getArgument = getWord32host + putArgument x = pure (putWord32host x, 4) + getArgument = pure <$> getWord32host showArgument = show instance WireFormat 'FixedArgument where type Argument 'FixedArgument = Fixed - putArgument (Fixed repr) = putWord32host repr - getArgument = Fixed <$> getWord32host + putArgument (Fixed repr) = pure (putWord32host repr, 4) + getArgument = pure . Fixed <$> getWord32host showArgument = show instance WireFormat 'StringArgument where - type Argument 'StringArgument = BS.ByteString - putArgument = putWaylandBlob - getArgument = getWaylandBlob + type Argument 'StringArgument = WlString + putArgument (WlString x) = pure $ putWaylandBlob x + getArgument = pure . WlString <$> getWaylandBlob showArgument = show instance WireFormat 'ArrayArgument where type Argument 'ArrayArgument = BS.ByteString - putArgument = putWaylandBlob - getArgument = getWaylandBlob + putArgument x = pure $ putWaylandBlob x + getArgument = pure <$> getWaylandBlob showArgument array = "[array " <> show (BS.length array) <> "B]" instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where type Argument (ObjectId j) = ObjectId j - putArgument (ObjectId oId) = putWord32host oId - getArgument = ObjectId <$> getWord32host + putArgument (ObjectId oId) = pure (putWord32host oId, 4) + getArgument = pure . ObjectId <$> getWord32host showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId instance WireFormat 'GenericObjectArgument where type Argument 'GenericObjectArgument = GenericObjectId - putArgument = putWord32host - getArgument = getWord32host + putArgument x = pure (putWord32host x, 4) + getArgument = pure <$> getWord32host showArgument oId = "[unknown]@" <> show oId instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where type Argument (NewId j) = NewId j - putArgument (NewId newId) = putWord32host newId - getArgument = NewId <$> getWord32host + putArgument (NewId newId) = pure (putWord32host newId, 4) + getArgument = pure . NewId <$> getWord32host showArgument (NewId newId) = "new " <> symbolVal @j Proxy <> "@" <> show newId instance WireFormat 'GenericNewIdArgument where type Argument 'GenericNewIdArgument = GenericNewId - putArgument (GenericNewId newId) = putWord32host newId - getArgument = GenericNewId <$> getWord32host + putArgument (GenericNewId newId) = pure (putWord32host newId, 4) + getArgument = pure . GenericNewId <$> getWord32host showArgument newId = "new [unknown]@" <> show newId instance WireFormat 'FdArgument where @@ -245,10 +247,10 @@ class ( => IsInterfaceSide (s :: Side) i -getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (Down s i) +getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (ProtocolM s (Down s i)) getDown = getMessage @(Down s i) -putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> PutM Opcode +putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s (Opcode, [(Put, Int)]) putUp _ = putMessage @(Up s i) @@ -307,8 +309,8 @@ instance IsObjectSide (SomeObject s) where class (Eq a, Show a) => IsMessage a where opcodeName :: Opcode -> Maybe String - getMessage :: IsInterface i => Object s i -> Opcode -> Get a - putMessage :: a -> PutM Opcode + getMessage :: IsInterface i => Object s i -> Opcode -> Get (ProtocolM s a) + putMessage :: a -> ProtocolM s (Opcode, [(Put, Int)]) instance IsMessage Void where opcodeName _ = Nothing @@ -413,6 +415,11 @@ modifyProtocolVar fn x = do state <- ask lift $ modifyTVar (fn state) x +modifyProtocolVar' :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s () +modifyProtocolVar' fn x = do + state <- ask + lift $ modifyTVar' (fn state) x + stateProtocolVar :: (ProtocolState s -> TVar a) -> (a -> (r, a)) -> ProtocolM s r stateProtocolVar fn x = do state <- ask @@ -484,7 +491,7 @@ runProtocolM protocol action = either throwM (runReaderT action) =<< readTVar pr feedInput :: (IsSide s, MonadIO m, MonadThrow m) => ProtocolHandle s -> ByteString -> m () feedInput protocol bytes = runProtocolTransaction protocol do -- Exposing MonadIO instead of STM to the outside and using `runProtocolTransaction` here enforces correct exception handling. - modifyProtocolVar (.bytesReceivedVar) (+ fromIntegral (BS.length bytes)) + modifyProtocolVar' (.bytesReceivedVar) (+ fromIntegral (BS.length bytes)) modifyProtocolVar (.inboxDecoderVar) (`pushChunk` bytes) receiveMessages @@ -498,7 +505,7 @@ takeOutbox protocol = runProtocolTransaction protocol do mOutboxData <- stateProtocolVar (.outboxVar) (\mOutboxData -> (mOutboxData, Nothing)) outboxData <- maybe (lift retry) pure mOutboxData let sendData = runPut outboxData - modifyProtocolVar (.bytesSentVar) (+ BSL.length sendData) + modifyProtocolVar' (.bytesSentVar) (+ BSL.length sendData) pure sendData @@ -540,25 +547,28 @@ newObjectFromId (NewId oId) callback = do -- | Sends a message without checking any ids or creating proxy objects objects. (TODO) sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s () sendMessage object message = do + (opcode, pairs) <- putUp object message + let (putBodyParts, partLengths) = unzip pairs + let putBody = mconcat putBodyParts + let bodyLength = foldr (+) 0 partLengths + let body = runPut putBody traceM $ "-> " <> showObjectMessage object message - sendRawMessage messageWithHeader + sendRawMessage $ messageWithHeader opcode body where - body :: BSL.ByteString - opcode :: Opcode - (opcode, body) = runPutM $ putUp object message - messageWithHeader :: Put - messageWithHeader = do + messageWithHeader :: Opcode -> BSL.ByteString -> Put + messageWithHeader opcode body = do putWord32host $ objectId object putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode putLazyByteString body - msgSize :: Word16 - msgSize = - if msgSizeInteger <= fromIntegral (maxBound :: Word16) - then fromIntegral msgSizeInteger - else error "Message too large" - -- TODO: body length should be returned from `putMessage`, instead of realizing it to a ByteString here - msgSizeInteger :: Integer - msgSizeInteger = 8 + fromIntegral (BSL.length body) + where + msgSize :: Word16 + msgSize = + if msgSizeInteger <= fromIntegral (maxBound :: Word16) + then fromIntegral msgSizeInteger + else error "Message too large" + -- TODO: body length should be returned from `putMessage`, instead of realizing it to a ByteString here + msgSizeInteger :: Integer + msgSizeInteger = 8 + fromIntegral (BSL.length body) receiveMessages :: IsSide s => ProtocolM s () @@ -591,8 +601,8 @@ getMessageAction -> Opcode -> Get (ProtocolM s ()) getMessageAction object@(Object _ objectHandler) opcode = do - message <- getDown object opcode - pure $ handleMessage objectHandler object message + verifyMessage <- getDown object opcode + pure $ handleMessage objectHandler object =<< verifyMessage type RawMessage = (GenericObjectId, Opcode, BSL.ByteString) @@ -627,13 +637,18 @@ getWaylandBlob = do skipPadding pure string -putWaylandBlob :: BS.ByteString -> Put -putWaylandBlob blob = do - let size = BS.length blob - putWord32host (fromIntegral (size + 1)) - putByteString blob - putWord8 0 - replicateM_ (padding size) (putWord8 0) +putWaylandBlob :: BS.ByteString -> (Put, Int) +putWaylandBlob blob = (putBlob, 4 + len + pad) + where + -- Total data length including null byte + len = BS.length blob + 1 + -- Padding length + pad = padding len + putBlob = do + putWord32host (fromIntegral (len + 1)) + putByteString blob + putWord8 0 + replicateM_ pad (putWord8 0) skipPadding :: Get () diff --git a/src/Quasar/Wayland/Protocol/Display.hs b/src/Quasar/Wayland/Protocol/Display.hs index 3b58452c598bd80387991f93921e8985b8aefd65..85a301195af85abdc378f7dba161becfc025dceb 100644 --- a/src/Quasar/Wayland/Protocol/Display.hs +++ b/src/Quasar/Wayland/Protocol/Display.hs @@ -3,7 +3,6 @@ module Quasar.Wayland.Protocol.Display ( ) where import Control.Monad.Catch -import Data.ByteString.UTF8 qualified as BS import Data.HashMap.Strict qualified as HM import Quasar.Prelude import Quasar.Wayland.Protocol.Core @@ -19,7 +18,7 @@ clientWlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Cli clientWlDisplayCallback = internalFnCallback handler where -- | wl_display is specified to never change, so manually specifying the callback is safe - handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolM 'Client () + handler :: Object 'Client I_wl_display -> WireEvent_wl_display -> ProtocolM 'Client () -- TODO parse oId - handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (BS.toString message) - handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete + handler _ (WireEvent_wl_display_error oId code message) = throwM $ ServerError code (toString message) + handler _ (WireEvent_wl_display_delete_id deletedId) = pure () -- TODO confirm delete diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs index 9f27a5f03eed048ce65960f08b12a2d75f4d99ea..85f60e92b7e163753e18ff57595945551ec96a7d 100644 --- a/src/Quasar/Wayland/Protocol/TH.hs +++ b/src/Quasar/Wayland/Protocol/TH.hs @@ -97,15 +97,15 @@ interfaceDecs interface = do rT :: Q Type rT = if length interface.requests > 0 then conT rTypeName else [t|Void|] rTypeName :: Name - rTypeName = mkName $ "R_" <> interface.name + rTypeName = mkName $ "WireRequest_" <> interface.name rConName :: RequestSpec -> Name - rConName (RequestSpec request) = mkName $ "R_" <> interface.name <> "_" <> request.name + rConName (RequestSpec request) = mkName $ "WireRequest_" <> interface.name <> "_" <> request.name eT :: Q Type eT = if length interface.events > 0 then conT eTypeName else [t|Void|] eTypeName :: Name - eTypeName = mkName $ "E_" <> interface.name + eTypeName = mkName $ "WireEvent_" <> interface.name eConName :: EventSpec -> Name - eConName (EventSpec event) = mkName $ "E_" <> interface.name <> "_" <> event.name + eConName (EventSpec event) = mkName $ "WireEvent_" <> interface.name <> "_" <> event.name requestContext :: RequestSpec -> MessageContext requestContext req@(RequestSpec msgSpec) = MessageContext { msgInterfaceT = iT, @@ -246,7 +246,7 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) [] where getMessageE :: Q Exp - getMessageE = applyA (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments) + getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments) getMessageInvalidOpcodeClause :: Q Clause getMessageInvalidOpcodeClause = do let object = mkName "object" @@ -258,10 +258,10 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, putMessageClauseD msg = clause [msgConP msg] (normalB (putMessageE msg.msgSpec.arguments)) [] where putMessageE :: [ArgumentSpec] -> Q Exp - putMessageE [] = opcodeE - putMessageE args = doE (((\arg -> noBindS [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args) <> [noBindS opcodeE]) - opcodeE :: Q Exp - opcodeE = [|pure $(litE $ integerL $ fromIntegral msg.msgSpec.opcode)|] + putMessageE args = [|($(litE $ integerL $ fromIntegral msg.msgSpec.opcode), ) <$> $(putMessageBodyE args)|] + putMessageBodyE :: [ArgumentSpec] -> Q Exp + putMessageBodyE [] = [|pure []|] + putMessageBodyE args = [|sequence $(listE ((\arg -> [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args))|] derivingEq :: Q DerivClause @@ -309,6 +309,12 @@ applyM con [] = con applyM con args = [|join $(applyA con args)|] +-- | (a -> b -> c -> d) -> [f (g a), f (g b), f (g c)] -> f (g d) +applyALifted :: Q Exp -> [Q Exp] -> Q Exp +applyALifted con [] = [|pure $ pure $con|] +applyALifted con (monadicE:monadicEs) = foldl (\x y -> [|$x <<*>> $y|]) [|$con <<$>> $monadicE|] monadicEs + + -- * XML parser parseProtocol :: MonadFail m => BS.ByteString -> m ProtocolSpec diff --git a/src/Quasar/Wayland/Registry.hs b/src/Quasar/Wayland/Registry.hs index 0bb0ef61803516ecf1679b6f35adeb7124384564..0f5013365204f58c0828ae19108e3a3d05320db1 100644 --- a/src/Quasar/Wayland/Registry.hs +++ b/src/Quasar/Wayland/Registry.hs @@ -15,7 +15,7 @@ import Quasar.Wayland.Protocol.Generated data ClientRegistry = ClientRegistry { wlRegistry :: Object 'Client I_wl_registry, - globalsVar :: TVar (HM.HashMap Word32 (BS.ByteString, Word32)) + globalsVar :: TVar (HM.HashMap Word32 (WlString, Word32)) } createClientRegistry :: Object 'Client I_wl_display -> ProtocolM 'Client ClientRegistry @@ -23,7 +23,7 @@ createClientRegistry wlDisplay = mfix \clientRegistry -> do globalsVar <- lift $ newTVar HM.empty (wlRegistry, newId) <- newObject @'Client @I_wl_registry (traceCallback (callback clientRegistry)) - sendMessage wlDisplay $ R_wl_display_get_registry newId + sendMessage wlDisplay $ WireRequest_wl_display_get_registry newId pure ClientRegistry { wlRegistry, @@ -34,10 +34,10 @@ createClientRegistry wlDisplay = mfix \clientRegistry -> do callback clientRegistry = internalFnCallback handler where -- | wl_registry is specified to never change, so manually specifying the callback is safe - handler :: Object 'Client I_wl_registry -> E_wl_registry -> ProtocolM 'Client () - handler _ (E_wl_registry_global name interface version) = do + handler :: Object 'Client I_wl_registry -> WireEvent_wl_registry -> ProtocolM 'Client () + handler _ (WireEvent_wl_registry_global name interface version) = do lift $ modifyTVar clientRegistry.globalsVar (HM.insert name (interface, version)) - handler _ (E_wl_registry_global_remove name) = do + handler _ (WireEvent_wl_registry_global_remove name) = do result <- lift $ stateTVar clientRegistry.globalsVar (swap . lookupDelete name) case result of Nothing -> traceM $ "Invalid global removed by server: " <> show name