diff --git a/src/Quasar/Wayland/Protocol.hs b/src/Quasar/Wayland/Protocol.hs index 6cbd1ad0dce4a447d7093bc0878a7a88cedb8546..3c042cdbaa77e8a20d90eafdabd7d2cac6024087 100644 --- a/src/Quasar/Wayland/Protocol.hs +++ b/src/Quasar/Wayland/Protocol.hs @@ -4,20 +4,13 @@ module Quasar.Wayland.Protocol ( -- "Quasar.Wayland.Protocol.TH". Object, + getMessageHandler, + setMessageHandler, - -- ** Wire types - ObjectId, - GenericObjectId, - NewId, + -- ** Wayland types Fixed(..), WlString(..), - -- ** Classes for generated interfaces - IsInterface(InterfaceName), - interfaceName, - Side(..), - IsSide, - -- ** Protocol execution ProtocolHandle, initializeProtocol, @@ -25,19 +18,18 @@ module Quasar.Wayland.Protocol ( takeOutbox, setException, - -- ** Low-level protocol interaction (TODO should no longer be required after cleanup) - ProtocolM, - runProtocolTransaction, - runProtocolM, - newObject, - sendMessage, - -- * Protocol exceptions WireCallbackFailed(..), ParserFailed(..), ProtocolException(..), MaximumIdReached(..), ServerError(..), + + -- ** Classes for generated interfaces + IsInterface(InterfaceName), + interfaceName, + Side(..), + IsSide, ) where import Quasar.Wayland.Protocol.Core diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index fa77c9ec01077840984261881205551073ceaff7..d38c423cc7a3595a0d29760ccd7a3f2d85afd1c3 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -16,7 +16,10 @@ module Quasar.Wayland.Protocol.Core ( interfaceName, IsInterfaceSide(..), IsInterfaceHandler(..), - Object, + Object(objectId), + getMessageHandler, + setMessageHandler, + NewObject, IsObject, IsMessage(..), ProtocolHandle, @@ -35,6 +38,7 @@ module Quasar.Wayland.Protocol.Core ( sendMessage, newObject, newObjectFromId, + getObject, -- * Protocol exceptions WireCallbackFailed(..), @@ -241,36 +245,46 @@ class IsInterfaceSide s i => IsInterfaceHandler s i a where data Side = Client | Server +-- | An object belonging to a wayland connection. data Object s i = IsInterfaceSide s i => Object { objectProtocol :: (ProtocolHandle s), - objectObjectId :: GenericObjectId, - messageHandler :: TMVar (MessageHandler s i) + objectId :: ObjectId (InterfaceName i), + messageHandler :: TVar (Maybe (MessageHandler s i)) } +getMessageHandler :: Object s i -> STM (MessageHandler s i) +getMessageHandler object = maybe retry pure =<< readTVar object.messageHandler + +setMessageHandler :: Object s i -> MessageHandler s i -> STM () +setMessageHandler object = writeTVar object.messageHandler . Just + +-- | Type alias to indicate an object is created with a message. +type NewObject s i = Object s i + instance IsInterface i => Show (Object s i) where show = showObject class IsObject a where - objectId :: a -> GenericObjectId + genericObjectId :: a -> GenericObjectId objectInterfaceName :: a -> String showObject :: a -> String - showObject object = objectInterfaceName object <> "@" <> show (objectId object) + showObject object = objectInterfaceName object <> "@" <> show (genericObjectId object) class IsObjectSide a where describeUpMessage :: a -> Opcode -> BSL.ByteString -> String describeDownMessage :: a -> Opcode -> BSL.ByteString -> String instance forall s i. IsInterface i => IsObject (Object s i) where - objectId = objectObjectId + genericObjectId object = toGenericObjectId object.objectId objectInterfaceName _ = interfaceName @i instance forall s i. IsInterfaceSide s i => IsObjectSide (Object s i) where describeUpMessage object opcode body = - objectInterfaceName object <> "@" <> show (objectId object) <> + objectInterfaceName object <> "@" <> show (genericObjectId object) <> "." <> fromMaybe "[invalidOpcode]" (opcodeName @(WireUp s i) opcode) <> " (" <> show (BSL.length body) <> "B)" describeDownMessage object opcode body = - objectInterfaceName object <> "@" <> show (objectId object) <> + objectInterfaceName object <> "@" <> show (genericObjectId object) <> "." <> fromMaybe "[invalidOpcode]" (opcodeName @(WireDown s i) opcode) <> " (" <> show (BSL.length body) <> "B)" @@ -280,8 +294,8 @@ data SomeObject s | UnknownObject String GenericObjectId instance IsObject (SomeObject s) where - objectId (SomeObject object) = objectId object - objectId (UnknownObject _ oId) = oId + genericObjectId (SomeObject object) = genericObjectId object + genericObjectId (UnknownObject _ oId) = oId objectInterfaceName (SomeObject object) = objectInterfaceName object objectInterfaceName (UnknownObject interface _) = interface @@ -308,7 +322,7 @@ instance IsMessage Void where invalidOpcode :: IsInterface i => Object s i -> Opcode -> Get a invalidOpcode object opcode = - fail $ "Invalid opcode " <> show opcode <> " on " <> objectInterfaceName object <> "@" <> show (objectId object) + fail $ "Invalid opcode " <> show opcode <> " on " <> objectInterfaceName object <> "@" <> show (genericObjectId object) showObjectMessage :: (IsObject a, IsMessage b) => a -> b -> String showObjectMessage object message = @@ -393,7 +407,7 @@ stateProtocolVar fn x = do initializeProtocol :: forall s wl_display a. (IsInterfaceSide s wl_display) => MessageHandler s wl_display - -> (Object s wl_display -> ProtocolM s a) + -> (Object s wl_display -> STM a) -> STM (a, ProtocolHandle s) initializeProtocol wlDisplayMessageHandler initializationAction = do bytesReceivedVar <- newTVar 0 @@ -423,15 +437,14 @@ initializeProtocol wlDisplayMessageHandler initializationAction = do } writeTVar stateVar (Right state) - messageHandlerVar <- newTMVar wlDisplayMessageHandler - let wlDisplay = Object protocol wlDisplayId messageHandlerVar - modifyTVar' objectsVar (HM.insert wlDisplayId (SomeObject wlDisplay)) + messageHandlerVar <- newTVar (Just wlDisplayMessageHandler) + let wlDisplay = Object protocol (ObjectId wlDisplayId) messageHandlerVar + modifyTVar' objectsVar (HM.insert (GenericObjectId wlDisplayId) (SomeObject wlDisplay)) - result <- runReaderT (initializationAction wlDisplay) state + result <- initializationAction wlDisplay pure (result, protocol) where - wlDisplayId :: GenericObjectId - wlDisplayId = GenericObjectId 1 + wlDisplayId = 1 -- | Run a protocol action in 'IO'. If an exception occurs, it is stored as a protocol failure and is then -- re-thrown. @@ -494,12 +507,12 @@ takeOutbox protocol = runProtocolTransaction protocol do -- Exported for use in TH generated code. newObject :: forall s i. IsInterfaceSide s i - => MessageHandler s i + => Maybe (MessageHandler s i) -> ProtocolM s (Object s i, NewId (InterfaceName i)) newObject messageHandler = do oId <- allocateObjectId let newId = NewId @(InterfaceName i) oId - object <- newObjectFromId newId messageHandler + object <- newObjectFromId messageHandler newId pure (object, newId) where allocateObjectId :: ProtocolM s (ObjectId (InterfaceName i)) @@ -512,26 +525,35 @@ newObject messageHandler = do writeProtocolVar (.nextIdVar) nextId' pure $ ObjectId id' + -- | Create an object from a received id. The caller is responsible for using a 'NewId' exactly once while handling an -- incoming message -- -- Exported for use in TH generated code. newObjectFromId :: forall s i. IsInterfaceSide s i - => NewId (InterfaceName i) - -> MessageHandler s i + => Maybe (MessageHandler s i) + -> NewId (InterfaceName i) -> ProtocolM s (Object s i) -newObjectFromId (NewId oId) messageHandler = do +newObjectFromId messageHandler (NewId oId) = do protocol <- askProtocol - messageHandlerVar <- lift $ newTMVar messageHandler + messageHandlerVar <- lift $ newTVar messageHandler let genericObjectId = toGenericObjectId oId - object = Object protocol genericObjectId messageHandlerVar + object = Object protocol oId messageHandlerVar someObject = SomeObject object modifyProtocolVar (.objectsVar) (HM.insert genericObjectId someObject) pure object +getObject + :: IsInterfaceSide s i + => ObjectId (InterfaceName i) + -> ProtocolM s (Object s i) +getObject = undefined + + + -- | Sends a message without checking any ids or creating proxy objects objects. (TODO) sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> ProtocolM s () sendMessage object message = do @@ -549,8 +571,8 @@ sendMessage object message = do traceM $ "-> " <> showObjectMessage object message sendRawMessage $ putHeader opcode bodyLength >> putBody where - oId = objectId object - (GenericObjectId objectIdWord) = objectId object + oId = genericObjectId object + (GenericObjectId objectIdWord) = genericObjectId object putHeader :: Opcode -> Int -> Put putHeader opcode msgSize = do putWord32host objectIdWord @@ -594,7 +616,7 @@ handleRawMessage (oId, opcode, body) = do pure do message <- verifyMessage traceM $ "<- " <> showObjectMessage object message - messageHandler <- lift $ readTMVar object.messageHandler + messageHandler <- lift $ getMessageHandler object handleMessage @s @i messageHandler message type RawMessage = (GenericObjectId, Opcode, BSL.ByteString) diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs index 362ee8650074af7d61a5fc17e20f87685d9e2345..f88d7d7bbec31fac4727030e09f41003e4d9cc58 100644 --- a/src/Quasar/Wayland/Protocol/TH.hs +++ b/src/Quasar/Wayland/Protocol/TH.hs @@ -46,6 +46,7 @@ data MessageSpec = MessageSpec { description :: Maybe DescriptionSpec, opcode :: Opcode, arguments :: [ArgumentSpec], + isConstructor :: Bool, isDestructor :: Bool } deriving stock Show @@ -179,16 +180,16 @@ interfaceDecs interface = do wireEventContexts = wireEventContext <$> interface.events requestCallbackRecordD :: Q Dec - requestCallbackRecordD = messageRecordD (requestsName interface) wireRequestContexts + requestCallbackRecordD = messageHandlerRecordD Server (requestsName interface) wireRequestContexts requestProxyInstanceDecs :: Q [Dec] - requestProxyInstanceDecs = messageProxyInstanceDecs [t|'Client|] wireRequestContexts + requestProxyInstanceDecs = messageProxyInstanceDecs Client wireRequestContexts eventCallbackRecordD :: Q Dec - eventCallbackRecordD = messageRecordD (eventsName interface) wireEventContexts + eventCallbackRecordD = messageHandlerRecordD Client (eventsName interface) wireEventContexts eventProxyInstanceDecs :: Q [Dec] - eventProxyInstanceDecs = messageProxyInstanceDecs [t|'Server|] wireEventContexts + eventProxyInstanceDecs = messageProxyInstanceDecs Server wireEventContexts handlerName = mkName "handler" handlerP = varP handlerName @@ -215,25 +216,77 @@ interfaceDecs interface = do msgHandlerE :: Q Exp msgHandlerE = [|$(appTypeE [|getField|] fieldNameLitT) $handlerE|] bodyE :: Q Exp - bodyE = [|lift $(applyMsgArgs msg msgHandlerE)|] + bodyE = [|lift =<< $(applyMsgArgs msg msgHandlerE)|] -messageProxyInstanceDecs :: Q Type -> [MessageContext] -> Q [Dec] -messageProxyInstanceDecs sideT messageContexts = mapM messageProxyInstanceD messageContexts + applyMsgArgs :: MessageContext -> Q Exp -> Q Exp + applyMsgArgs msg base = applyA base (argE <$> msg.msgSpec.arguments) + + argE :: ArgumentSpec -> Q Exp + argE arg = fromWireArgument arg.argType (msgArgE msg arg) + + fromWireArgument :: ArgumentType -> Q Exp -> Q Exp + fromWireArgument (ObjectArgument iName) objIdE = [|getObject $objIdE|] + fromWireArgument (NewIdArgument iName) objIdE = [|newObjectFromId Nothing $objIdE|] + fromWireArgument _ x = [|pure $x|] + +messageProxyInstanceDecs :: Side -> [MessageContext] -> Q [Dec] +messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messageContexts where messageProxyInstanceD :: MessageContext -> Q Dec messageProxyInstanceD msg = instanceD (pure []) instanceT [ - funD 'getField [clause ([varP objectName] <> msgArgPats msg) (normalB [|enterObject object (sendMessage object $(msgE msg))|]) []] + funD 'getField [clause ([varP objectName] <> msgProxyArgPats msg) (normalB [|enterObject object $actionE|]) []] ] where objectName = mkName "object" instanceT :: Q Type instanceT = [t|HasField $(litT (strTyLit msg.msgSpec.name)) $objectT $proxyT|] objectT :: Q Type - objectT = [t|Object $sideT $(msg.msgInterfaceT)|] + objectT = [t|Object $(sideT side) $(msg.msgInterfaceT)|] proxyT :: Q Type - proxyT = [t|$(applyArgTypes [t|STM ()|])|] + proxyT = [t|$(applyArgTypes [t|STM $returnT|])|] + returnT :: Q Type + returnT = maybe [t|()|] (argumentType side) (proxyReturnArgument msg.msgSpec) applyArgTypes :: Q Type -> Q Type - applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments) + applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> args) + + args :: [ArgumentSpec] + args = proxyArguments msg.msgSpec + + actionE :: Q Exp + actionE = if msg.msgSpec.isConstructor then ctorE else normalE + + -- Constructor: the first argument becomes the return value + ctorE :: Q Exp + ctorE = [|newObject Nothing >>= \(newObject, newId) -> newObject <$ (sendMessage object =<< $(msgE [|pure newId|]))|] + where + msgE :: Q Exp -> Q Exp + msgE idArgE = mkWireMsgE (idArgE : (wireArgE <$> args)) + + -- Body for a normal (i.e. non-constructor) proxy + normalE :: Q Exp + normalE = [|sendMessage object =<< $(msgE)|] + where + msgE :: Q Exp + msgE = mkWireMsgE (wireArgE <$> args) + + mkWireMsgE :: [Q Exp] -> Q Exp + mkWireMsgE mkWireArgEs = applyA (conE msg.msgConName) mkWireArgEs + + wireArgE :: ArgumentSpec -> Q Exp + wireArgE arg = toWireArgument arg.argType (msgArgE msg arg) + + toWireArgument :: ArgumentType -> Q Exp -> Q Exp + -- TODO verify object validity + toWireArgument (ObjectArgument iName) objectE = [|pure $objectE.objectId|] + toWireArgument (NewIdArgument _) _ = impossibleCodePath -- The specification parser has a check to prevent this + toWireArgument _ x = [|pure $x|] + +proxyArguments :: MessageSpec -> [ArgumentSpec] +proxyArguments msg = (if msg.isConstructor then drop 1 else id) msg.arguments + +proxyReturnArgument :: MessageSpec -> Maybe ArgumentSpec +proxyReturnArgument msg@MessageSpec{arguments=(firstArg:_)} = if msg.isConstructor then Just firstArg else Nothing +proxyReturnArgument _ = Nothing messageFieldName :: MessageContext -> Name @@ -242,15 +295,15 @@ messageFieldName msg = mkName $ messageFieldNameString msg messageFieldNameString :: MessageContext -> String messageFieldNameString msg = msg.msgSpec.name -messageRecordD :: Name -> [MessageContext] -> Q Dec -messageRecordD name messageContexts = dataD (cxt []) name [] Nothing [con] [] +messageHandlerRecordD :: Side -> Name -> [MessageContext] -> Q Dec +messageHandlerRecordD side name messageContexts = dataD (cxt []) name [] Nothing [con] [] where con = recC name (recField <$> messageContexts) recField :: MessageContext -> Q VarBangType recField msg = varDefaultBangType (messageFieldName msg) [t|$(applyArgTypes [t|STM ()|])|] where applyArgTypes :: Q Type -> Q Type - applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments) + applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType side <$> msg.msgSpec.arguments) sideTVarName :: Name @@ -258,6 +311,10 @@ sideTVarName = mkName "s" sideTVar :: Q Type sideTVar = varT sideTVarName +sideT :: Side -> Q Type +sideT Client = [t|'Client|] +sideT Server = [t|'Server|] + interfaceN :: InterfaceSpec -> Name interfaceN interface = mkName $ "Interface_" <> interface.name @@ -295,29 +352,34 @@ data MessageContext = MessageContext { msgSpec :: MessageSpec } --- | Pattern to match a message. Arguments can then be accessed by using 'msgArgE'. +-- | Pattern to match a wire message. Arguments can then be accessed by using 'msgArgE'. msgConP :: MessageContext -> Q Pat msgConP msg = conP msg.msgConName (msgArgPats msg) --- | Pattern to match all arguments of a message. Arguments can then be accessed by using e.g. 'msgArgE'. +-- | Pattern to match all arguments of a message (wire/handler). Arguments can then be accessed by using e.g. 'msgArgE'. msgArgPats :: MessageContext -> [Q Pat] msgArgPats msg = varP . msgArgTempName <$> msg.msgSpec.arguments +-- | Pattern to match all arguments of a message (for a proxy). Arguments can then be accessed by using e.g. 'msgArgE'. +msgProxyArgPats :: MessageContext -> [Q Pat] +msgProxyArgPats msg = varP . msgArgTempName <$> proxyArguments msg.msgSpec + -- | Expression for accessing a message argument which has been matched from a request/event using 'msgArgConP'. msgArgE :: MessageContext -> ArgumentSpec -> Q Exp msgArgE _msg arg = varE (msgArgTempName arg) -- | Helper for 'msgConP' and 'msgArgE'. msgArgTempName :: ArgumentSpec -> Name --- Add a prefix to prevent name conflicts with exports from the Prelude +-- Adds a prefix to prevent name conflicts with exports from the Prelude; would be better to use `newName` instead. msgArgTempName arg = mkName $ "arg_" <> arg.name -applyMsgArgs :: MessageContext -> Q Exp -> Q Exp -applyMsgArgs msg base = foldl appE base (msgArgE msg <$> msg.msgSpec.arguments) +applyWireMsgArgs :: MessageContext -> Q Exp -> Q Exp +applyWireMsgArgs msg base = foldl appE base (msgArgE msg <$> msg.msgSpec.arguments) -- | Expression to construct a wire message with arguments which have been matched using 'msgConP'/'msgArgPats'. -msgE :: MessageContext -> Q Exp -msgE msg = applyMsgArgs msg (conE msg.msgConName) +-- TODO Unused? +wireMsgE :: MessageContext -> Q Exp +wireMsgE msg = applyWireMsgArgs msg (conE msg.msgConName) messageTypeDecs :: Name -> [MessageContext] -> Q [Dec] @@ -340,12 +402,10 @@ messageTypeDecs name msgs = execWriterT do showD :: Q Dec showD = funD 'show (showClause <$> msgs) showClause :: MessageContext -> Q Clause - showClause msg = - clause - [msgConP msg] - (normalB [|mconcat $(listE ([stringE (msg.msgSpec.name ++ "(")] <> mconcat (intersperse [stringE ", "] (showArgE <$> msg.msgSpec.arguments) <> [[stringE ")"]])))|]) - [] + showClause msg = clause [msgConP msg] (normalB bodyE) [] where + bodyE :: Q Exp + bodyE = [|mconcat $(listE ([stringE (msg.msgSpec.name ++ "(")] <> mconcat (intersperse [stringE ", "] (showArgE <$> msg.msgSpec.arguments) <> [[stringE ")"]])))|] showArgE :: ArgumentSpec -> [Q Exp] showArgE arg = [stringE (arg.name ++ "="), [|showArgument @($(argumentWireType arg)) $(msgArgE msg arg)|]] @@ -389,12 +449,13 @@ derivingShow :: Q DerivClause derivingShow = derivClause (Just StockStrategy) [[t|Show|]] -- | Map an argument to its high-level api type -argumentType :: ArgumentSpec -> Q Type -argumentType argSpec = liftArgumentType argSpec.argType +argumentType :: Side -> ArgumentSpec -> Q Type +argumentType side argSpec = liftArgumentType side argSpec.argType -liftArgumentType :: ArgumentType -> Q Type ---liftArgumentType (ObjectArgument iName) = [t|Object $sideTVar $(interfaceTFromName iName)|] -liftArgumentType x = liftArgumentWireType x +liftArgumentType :: Side -> ArgumentType -> Q Type +liftArgumentType side (ObjectArgument iName) = [t|Object $(sideT side) $(interfaceTFromName iName)|] +liftArgumentType side (NewIdArgument iName) = [t|NewObject $(sideT side) $(interfaceTFromName iName)|] +liftArgumentType _ x = liftArgumentWireType x -- | Map an argument to its wire representation type @@ -528,10 +589,6 @@ parseMessage isRequest interface (opcode, element) = do do isEvent && isDestructor do fail $ "Event cannot be a destructor: " <> location - when - do (foldr (\arg -> if isNewId arg.argType then (+ 1) else id) 0 arguments) > (1 :: Int) - do fail $ "Message creates multiple objects: " <> location - forM_ arguments \arg -> do when do arg.argType == GenericNewIdArgument && (interface /= "wl_registry" || name /= "bind") @@ -540,12 +597,21 @@ parseMessage isRequest interface (opcode, element) = do do arg.argType == GenericObjectArgument && (interface /= "wl_display" || name /= "error") do fail $ "Invalid \"object\" argument without \"interface\" attribute encountered on " <> location <> " (only valid on wl_display.error)" + isConstructor <- case arguments of + [] -> pure False + (firstArg:otherArgs) -> do + when + do any (isNewId . (.argType)) otherArgs && not (interface == "wl_registry" && name == "bind") + do fail $ "Message uses NewId in unexpected position on: " <> location <> " (NewId must be the first argument, unless it is on wl_registry.bind)" + pure (isNewId firstArg.argType) + pure MessageSpec { name, since, description, opcode, arguments, + isConstructor, isDestructor } diff --git a/src/Quasar/Wayland/Registry.hs b/src/Quasar/Wayland/Registry.hs index 846e447b64643adac5fcee2c694a7df5d39b3a05..9b8b7982ea328f7b7c01782c286250f5dafbbb04 100644 --- a/src/Quasar/Wayland/Registry.hs +++ b/src/Quasar/Wayland/Registry.hs @@ -18,12 +18,12 @@ data ClientRegistry = ClientRegistry { globalsVar :: TVar (HM.HashMap Word32 (WlString, Word32)) } -createClientRegistry :: Object 'Client Interface_wl_display -> ProtocolM 'Client ClientRegistry +createClientRegistry :: Object 'Client Interface_wl_display -> STM ClientRegistry createClientRegistry wlDisplay = mfix \clientRegistry -> do - globalsVar <- lift $ newTVar HM.empty + globalsVar <- newTVar HM.empty - (wlRegistry, newId) <- newObject @'Client @Interface_wl_registry (messageHandler clientRegistry) - sendMessage wlDisplay $ WireRequest_wl_display__get_registry newId + wlRegistry <- wlDisplay.get_registry + setMessageHandler wlRegistry (messageHandler clientRegistry) pure ClientRegistry { wlRegistry,