From a2576bfd6c4c54e9ac3324d3e44bc24698ff5e4d Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Thu, 23 Dec 2021 20:29:50 +0100
Subject: [PATCH] Add support for nullable object id arguments

---
 src/Quasar/Wayland/Protocol/Core.hs | 16 +++++++++++++
 src/Quasar/Wayland/Protocol/TH.hs   | 36 ++++++++++++++++++-----------
 2 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index 761ad27..21edd76 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -40,12 +40,14 @@ module Quasar.Wayland.Protocol.Core (
 
   -- * Low-level protocol interaction
   objectWireArgument,
+  nullableObjectWireArgument,
   checkObject,
   sendMessage,
   newObject,
   newObjectFromId,
   bindNewObject,
   getObject,
+  getNullableObject,
   lookupObject,
   buildMessage,
 
@@ -643,6 +645,15 @@ getObject
   -> ProtocolM s (Object s i)
 getObject oId = either (throwM . ProtocolException . ("Received invalid object id: " <>)) pure =<< lookupObject oId
 
+-- | Lookup an object for an id or throw a `ProtocolException`. To be used from generated code when receiving an object
+-- id.
+getNullableObject
+  :: forall s i. IsInterfaceSide s i
+  => ObjectId (InterfaceName i)
+  -> ProtocolM s (Maybe (Object s i))
+getNullableObject (ObjectId 0) = pure Nothing
+getNullableObject oId = Just <$> getObject oId
+
 
 
 -- | Handle a wl_display.error message. Because this is part of the core protocol but generated from the xml it has to
@@ -675,6 +686,11 @@ objectWireArgument object = do
     Left msg -> throwM $ ProtocolUsageError $ "Tried to send a reference to an invalid object: " <> msg
     Right () -> pure object.objectId
 
+-- | Verify that an object can be used as an argument (throws otherwise) and return its id.
+nullableObjectWireArgument :: IsInterface i => Maybe (Object s i) -> ProtocolM s (ObjectId (InterfaceName i))
+nullableObjectWireArgument Nothing = pure (ObjectId 0)
+nullableObjectWireArgument (Just object) = objectWireArgument object
+
 
 -- | Sends a message, for use in generated code.
 sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> ProtocolM s ()
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index 6b56de5..f5766ee 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -88,6 +88,7 @@ data ArgumentType
   | StringArgument
   | ArrayArgument
   | ObjectArgument String
+  | NullableObjectArgument String
   | GenericObjectArgument
   | NewIdArgument String
   | GenericNewIdArgument
@@ -257,6 +258,7 @@ interfaceDecs interface = do
 
         fromWireArgument :: ArgumentType -> Q Exp -> Q Exp
         fromWireArgument (ObjectArgument _) objIdE = [|getObject $objIdE|]
+        fromWireArgument (NullableObjectArgument _) objIdE = [|getNullableObject $objIdE|]
         fromWireArgument (NewIdArgument _) objIdE = [|newObjectFromId Nothing $objIdE|]
         fromWireArgument _ x = [|pure $x|]
 
@@ -308,6 +310,7 @@ messageProxyInstanceDecs side messageContexts = mapM messageProxyInstanceD messa
 
         toWireArgument :: ArgumentType -> Q Exp -> Q Exp
         toWireArgument (ObjectArgument _) objectE = [|objectWireArgument $objectE|]
+        toWireArgument (NullableObjectArgument _) objectE = [|nullableObjectWireArgument $objectE|]
         toWireArgument (NewIdArgument _) _ = unreachableCodePath -- The specification parser has a check to prevent this
         toWireArgument _ x = [|pure $x|]
 
@@ -464,6 +467,7 @@ argumentType side argSpec = liftArgumentType side argSpec.argType
 
 liftArgumentType :: Side -> ArgumentType -> Q Type
 liftArgumentType side (ObjectArgument iName) = [t|Object $(sideT side) $(interfaceTFromName iName)|]
+liftArgumentType side (NullableObjectArgument iName) = [t|Maybe (Object $(sideT side) $(interfaceTFromName iName))|]
 liftArgumentType side (NewIdArgument iName) = [t|NewObject $(sideT side) $(interfaceTFromName iName)|]
 liftArgumentType _ x = liftArgumentWireType x
 
@@ -479,6 +483,7 @@ liftArgumentWireType FixedArgument = [t|Fixed|]
 liftArgumentWireType StringArgument = [t|WlString|]
 liftArgumentWireType ArrayArgument = [t|BS.ByteString|]
 liftArgumentWireType (ObjectArgument iName) = [t|ObjectId $(litT (strTyLit iName))|]
+liftArgumentWireType (NullableObjectArgument iName) = [t|ObjectId $(litT (strTyLit iName))|]
 liftArgumentWireType GenericObjectArgument = [t|GenericObjectId|]
 liftArgumentWireType (NewIdArgument iName) = [t|NewId $(litT (strTyLit iName))|]
 liftArgumentWireType GenericNewIdArgument = [t|GenericNewId|]
@@ -631,7 +636,6 @@ parseArgument messageDescription (index, element) = do
   summary <- peekAttr "summary" element
   argTypeStr <- getAttr "type" element
   interface <- peekAttr "interface" element
-  argType <- parseArgumentType argTypeStr interface
 
   let loc = messageDescription <> "." <> name
 
@@ -640,6 +644,9 @@ parseArgument messageDescription (index, element) = do
     Just "false" -> pure False
     Just x -> fail $ "Invalid value for attribute \"allow-null\" on " <> loc <> ": " <> x
     Nothing -> pure False
+
+  argType <- parseArgumentType argTypeStr interface nullable
+
   pure ArgumentSpec {
     name,
     index,
@@ -648,19 +655,20 @@ parseArgument messageDescription (index, element) = do
     nullable
   }
   where
-    parseArgumentType :: String -> Maybe String -> m ArgumentType
-    parseArgumentType "int" Nothing = pure IntArgument
-    parseArgumentType "uint" Nothing = pure UIntArgument
-    parseArgumentType "fixed" Nothing = pure FixedArgument
-    parseArgumentType "string" Nothing = pure StringArgument
-    parseArgumentType "array" Nothing = pure ArrayArgument
-    parseArgumentType "object" (Just interface) = pure (ObjectArgument interface)
-    parseArgumentType "object" Nothing = pure GenericObjectArgument
-    parseArgumentType "new_id" (Just interface) = pure (NewIdArgument interface)
-    parseArgumentType "new_id" Nothing = pure GenericNewIdArgument
-    parseArgumentType "fd" Nothing = pure FdArgument
-    parseArgumentType x Nothing = fail $ "Unknown argument type \"" <> x <> "\" encountered"
-    parseArgumentType x _ = fail $ "Argument type \"" <> x <> "\" should not have \"interface\" attribute"
+    parseArgumentType :: String -> Maybe String -> Bool -> m ArgumentType
+    parseArgumentType "int" Nothing _ = pure IntArgument
+    parseArgumentType "uint" Nothing _ = pure UIntArgument
+    parseArgumentType "fixed" Nothing _ = pure FixedArgument
+    parseArgumentType "string" Nothing _ = pure StringArgument
+    parseArgumentType "array" Nothing _ = pure ArrayArgument
+    parseArgumentType "object" (Just interface) False = pure (ObjectArgument interface)
+    parseArgumentType "object" (Just interface) True = pure (NullableObjectArgument interface)
+    parseArgumentType "object" Nothing _ = pure GenericObjectArgument
+    parseArgumentType "new_id" (Just interface) _ = pure (NewIdArgument interface)
+    parseArgumentType "new_id" Nothing _ = pure GenericNewIdArgument
+    parseArgumentType "fd" Nothing _ = pure FdArgument
+    parseArgumentType x Nothing _ = fail $ "Unknown argument type \"" <> x <> "\" encountered"
+    parseArgumentType x _ _ = fail $ "Argument type \"" <> x <> "\" should not have \"interface\" attribute"
 
 
 parseEnum :: MonadFail m => Element -> m EnumSpec
-- 
GitLab