From 5ac56da09288dfa40c830d4841b1620d624b7885 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Wed, 15 Sep 2021 22:37:45 +0200
Subject: [PATCH] Add api to create objects; make ids typesafe; remove
 initialization hack

---
 src/Quasar/Wayland/Client.hs        |  47 ++++++-----
 src/Quasar/Wayland/Connection.hs    |  20 ++---
 src/Quasar/Wayland/Protocol/Core.hs | 120 ++++++++++++++++++++--------
 src/Quasar/Wayland/Protocol/TH.hs   |  46 +++++++----
 4 files changed, 147 insertions(+), 86 deletions(-)

diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs
index 65e9d5f..658a2ac 100644
--- a/src/Quasar/Wayland/Client.hs
+++ b/src/Quasar/Wayland/Client.hs
@@ -1,6 +1,7 @@
 module Quasar.Wayland.Client (
   connectWaylandClient,
   newWaylandClient,
+  connectWaylandSocket,
 ) where
 
 import Control.Concurrent.STM
@@ -22,42 +23,40 @@ import System.FilePath ((</>), isRelative)
 import Text.Read (readEither)
 
 
-data WaylandClient = WaylandClient (WaylandConnection 'Client)
+data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client STM I_wl_display)
 
 instance IsResourceManager WaylandClient where
-  toResourceManager (WaylandClient connection) = toResourceManager connection
+  toResourceManager (WaylandClient connection _) = toResourceManager connection
 
 instance IsDisposable WaylandClient where
-  toDisposable (WaylandClient connection) = toDisposable connection
+  toDisposable (WaylandClient connection _) = toDisposable connection
 
 newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient
-newWaylandClient socket = WaylandClient <$>
-  newWaylandConnection
-    @I_wl_display
-    @I_wl_registry
-    (traceCallback ignoreMessage)
-    -- HACK to send get_registry
-    (Just (R_wl_display_get_registry (NewId 2)))
-    (traceCallback ignoreMessage)
-    socket
+newWaylandClient socket = do
+  (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket
+
+  (wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @STM @I_wl_registry (traceCallback ignoreMessage)
+  stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId
+  pure $ WaylandClient connection wlDisplay
 
 connectWaylandClient :: MonadResourceManager m => m WaylandClient
 connectWaylandClient = mask_ do
   socket <- liftIO connectWaylandSocket
   newWaylandClient socket
-  where
-    connectWaylandSocket :: IO Socket
-    connectWaylandSocket = do
-      lookupEnv "WAYLAND_SOCKET" >>= \case
-        -- Parent process already established connection
-        Just waylandSocketEnv -> do
-          case readEither waylandSocketEnv of
-            Left err -> fail $ "Failed to parse WAYLAND_SOCKET: " <> err
-            Right fd -> Socket.mkSocket fd
-        Nothing -> do
-          path <- getWaylandSocketPath
-          newUnixSocket path
 
+connectWaylandSocket :: IO Socket
+connectWaylandSocket = do
+  lookupEnv "WAYLAND_SOCKET" >>= \case
+    -- Parent process already established connection
+    Just waylandSocketEnv -> do
+      case readEither waylandSocketEnv of
+        Left err -> fail $ "Failed to parse WAYLAND_SOCKET: " <> err
+        Right fd -> Socket.mkSocket fd
+    Nothing -> do
+      path <- getWaylandSocketPath
+      newUnixSocket path
+
+  where
     getWaylandSocketPath :: IO FilePath
     getWaylandSocketPath = do
       waylandDisplayEnv <- lookupEnv "WAYLAND_DISPLAY"
diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs
index 8efeed6..f5b0b99 100644
--- a/src/Quasar/Wayland/Connection.hs
+++ b/src/Quasar/Wayland/Connection.hs
@@ -1,6 +1,7 @@
 module Quasar.Wayland.Connection (
   WaylandConnection,
   newWaylandConnection,
+  stepProtocol,
 ) where
 
 import Control.Concurrent.STM
@@ -35,13 +36,11 @@ data SocketClosed = SocketClosed
   deriving anyclass Exception
 
 newWaylandConnection
-  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry, MonadResourceManager m)
+  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m)
   => Callback s STM wl_display
-  -> Maybe (Up s wl_display)
-  -> Callback s STM wl_registry
   -> Socket
-  -> m (WaylandConnection s)
-newWaylandConnection wlDisplayCallback initializationMessage wlRegistryCallback socket = do
+  -> m (WaylandConnection s, Object s STM wl_display)
+newWaylandConnection wlDisplayCallback socket = do
   protocolStateVar <- liftIO $ newTVarIO protocolState
   outboxVar <- liftIO newEmptyTMVarIO
 
@@ -61,16 +60,9 @@ newWaylandConnection wlDisplayCallback initializationMessage wlRegistryCallback
       connectionThread connection $ sendThread connection
       connectionThread connection $ receiveThread connection
 
-    -- Create registry, if requested
-    forM_ initializationMessage \msg ->
-      sendProtocolMessage connection wlDisplay msg
-
-    pure connection
+    pure (connection, wlDisplay)
   where
-    (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback wlRegistryCallback
-
-sendProtocolMessage :: (IsInterfaceSide s i, MonadIO m) => WaylandConnection s -> Object s STM i -> Up s i -> m ()
-sendProtocolMessage connection object message = stepProtocol connection $ sendMessage object message
+    (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback
 
 stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s STM a -> m a
 stepProtocol connection step = liftIO do
diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index 6616854..3ab1775 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -2,6 +2,7 @@
 
 module Quasar.Wayland.Protocol.Core (
   ObjectId,
+  GenericObjectId,
   NewId(..),
   Opcode,
   ArgumentType(..),
@@ -22,6 +23,7 @@ module Quasar.Wayland.Protocol.Core (
   ProtocolStep,
   initialProtocolState,
   sendMessage,
+  newObject,
   feedInput,
   setException,
 
@@ -47,12 +49,18 @@ import Data.ByteString.Lazy qualified as BSL
 import Data.HashMap.Strict (HashMap)
 import Data.HashMap.Strict qualified as HM
 import Data.Maybe (isJust)
+import Data.Proxy
 import Data.Void (absurd)
+import GHC.TypeLits
 import Language.Haskell.TH.Syntax (Lift)
 import Quasar.Prelude
 
 
-type ObjectId = Word32
+newtype ObjectId (j :: Symbol) = ObjectId GenericObjectId
+  deriving stock (Eq, Show)
+
+type GenericObjectId = Word32
+
 type Opcode = Word16
 
 -- | Signed 24.8 decimal numbers.
@@ -62,15 +70,17 @@ newtype Fixed = Fixed Word32
 instance Show Fixed where
   show x = "[fixed " <> show x <> "]"
 
-newtype NewId = NewId ObjectId
-  deriving newtype (Eq, Show)
+newtype NewId (j :: Symbol) = NewId GenericObjectId
+  deriving stock (Eq, Show)
+
+newtype GenericNewId = GenericNewId GenericObjectId
+  deriving stock (Eq, Show)
 
 
 dropRemaining :: Get ()
 dropRemaining = void getRemainingLazyByteString
 
 
-
 data ArgumentType
   = IntArgument
   | UIntArgument
@@ -78,9 +88,9 @@ data ArgumentType
   | StringArgument
   | ArrayArgument
   | ObjectArgument String
-  | UnknownObjectArgument
+  | GenericObjectArgument
   | NewIdArgument String
-  | UnknownNewIdArgument
+  | GenericNewIdArgument
   | FdArgument
   deriving stock (Eq, Show, Lift)
 
@@ -120,29 +130,29 @@ instance WireFormat 'ArrayArgument where
   getArgument = getWaylandBlob
   showArgument array = "[array " <> show (BS.length array) <> "B]"
 
-instance WireFormat 'ObjectArgument where
-  type Argument 'ObjectArgument = ObjectId
-  putArgument = putWord32host
-  getArgument = getWord32host
-  showArgument oId = "@" <> show oId
+instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where
+  type Argument (ObjectId j) = ObjectId j
+  putArgument (ObjectId oId) = putWord32host oId
+  getArgument = ObjectId <$> getWord32host
+  showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId
 
-instance WireFormat 'UnknownObjectArgument where
-  type Argument 'UnknownObjectArgument = ObjectId
+instance WireFormat 'GenericObjectArgument where
+  type Argument 'GenericObjectArgument = GenericObjectId
   putArgument = putWord32host
   getArgument = getWord32host
-  showArgument oId = "@" <> show oId
+  showArgument oId = "[unknown]@" <> show oId
 
-instance WireFormat 'NewIdArgument where
-  type Argument 'NewIdArgument = NewId
+instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where
+  type Argument (NewId j) = NewId j
   putArgument (NewId newId) = putWord32host newId
   getArgument = NewId <$> getWord32host
-  showArgument newId = "new @" <> show newId
+  showArgument (NewId newId) = "new " <> symbolVal @j Proxy <> "@" <> show newId
 
-instance WireFormat 'UnknownNewIdArgument where
-  type Argument 'UnknownNewIdArgument = NewId
-  putArgument (NewId newId) = putWord32host newId
-  getArgument = NewId <$> getWord32host
-  showArgument newId = "new @" <> show newId
+instance WireFormat 'GenericNewIdArgument where
+  type Argument 'GenericNewIdArgument = GenericNewId
+  putArgument (GenericNewId newId) = putWord32host newId
+  getArgument = GenericNewId <$> getWord32host
+  showArgument newId = "new [unknown]@" <> show newId
 
 instance WireFormat 'FdArgument where
   type Argument 'FdArgument = Void
@@ -159,19 +169,27 @@ class (
   => IsInterface i where
   type Request i
   type Event i
+  type InterfaceName i :: Symbol
   interfaceName :: String
 
 class IsSide (s :: Side) where
   type Up s i
   type Down s i
+  initialId :: GenericObjectId
+  maximumId :: GenericObjectId
 
 instance IsSide 'Client where
   type Up 'Client i = Request i
   type Down 'Client i = Event i
+  -- Id #1 is reserved for wl_display
+  initialId = 2
+  maximumId = 0xfeffffff
 
 instance IsSide 'Server where
   type Up 'Server i = Event i
   type Down 'Server i = Request i
+  initialId = 0xff000000
+  maximumId = 0xffffffff
 
 
 --- | Empty class, used to combine constraints
@@ -198,10 +216,10 @@ class IsInterfaceSide s i => IsInterfaceHandler s m i a where
 -- | Data kind
 data Side = Client | Server
 
-data Object s m i = IsInterfaceSide s i => Object ObjectId (Callback s m i)
+data Object s m i = IsInterfaceSide s i => Object GenericObjectId (Callback s m i)
 
 class IsObject a where
-  objectId :: a -> ObjectId
+  objectId :: a -> GenericObjectId
   objectInterfaceName :: a -> String
 
 class IsObjectSide a where
@@ -225,7 +243,7 @@ instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where
 -- | Wayland object quantification wrapper
 data SomeObject s m
   = forall i. IsInterfaceSide s i => SomeObject (Object s m i)
-  | UnknownObject String ObjectId
+  | UnknownObject String GenericObjectId
 
 instance IsObject (SomeObject s m) where
   objectId (SomeObject object) = objectId object
@@ -269,7 +287,8 @@ data ProtocolState (s :: Side) m = ProtocolState {
   bytesSent :: !Int64,
   inboxDecoder :: Decoder RawMessage,
   outbox :: Maybe Put,
-  objects :: HashMap ObjectId (SomeObject s m)
+  objects :: HashMap GenericObjectId (SomeObject s m),
+  nextId :: Word32
 }
 
 
@@ -320,10 +339,15 @@ data ProtocolException = ProtocolException String
   deriving stock Show
   deriving anyclass Exception
 
+data MaximumIdReached = MaximumIdReached
+  deriving stock Show
+  deriving anyclass Exception
+
 -- * Monad plumbing
 
 type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)
 
+-- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions.
 type ProtocolAction s m a = StateT (ProtocolState s m) m a
 
 protocolStep :: forall s m a. MonadCatch m => ProtocolAction s m a -> ProtocolStep s m a
@@ -345,16 +369,13 @@ protocolStep action inState = do
 -- * Exported functions
 
 initialProtocolState
-  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry)
+  :: forall wl_display wl_registry s m. IsInterfaceSide s wl_display
   => Callback s m wl_display
-  -> Callback s m wl_registry
   -> (ProtocolState s m, Object s m wl_display)
-initialProtocolState wlDisplayCallback wlRegistryCallback = (initialState, wlDisplay)
+initialProtocolState wlDisplayCallback = (initialState, wlDisplay)
   where
     wlDisplay :: Object s m wl_display
     wlDisplay = Object 1 wlDisplayCallback
-    wlRegistry :: Object s m wl_registry
-    wlRegistry = Object 2 wlRegistryCallback
     initialState :: ProtocolState s m
     initialState = ProtocolState {
       protocolException = Nothing,
@@ -362,7 +383,8 @@ initialProtocolState wlDisplayCallback wlRegistryCallback = (initialState, wlDis
       bytesSent = 0,
       inboxDecoder = runGetIncremental getRawMessage,
       outbox = Nothing,
-      objects = HM.fromList [(1, (SomeObject wlDisplay)), (2, (SomeObject wlRegistry))]
+      objects = HM.fromList [(1, (SomeObject wlDisplay))],
+      nextId = initialId @s
     }
 
 -- | Feed the protocol newly received data
@@ -380,6 +402,38 @@ setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
 setException ex = protocolStep do
   State.modify \st -> st{protocolException = Just (toException ex)}
 
+
+-- Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object.
+newObject
+  :: forall s m i. (IsInterfaceSide s i, MonadCatch m)
+  => Callback s m i
+  -> ProtocolStep s m (Object s m i, NewId (InterfaceName i))
+newObject callback = protocolStep $ newObjectInternal callback
+
+newObjectInternal
+  :: forall s m i. (IsInterfaceSide s i, MonadCatch m)
+  => Callback s m i
+  -> ProtocolAction s m (Object s m i, NewId (InterfaceName i))
+newObjectInternal callback = do
+  oId <- allocateObjectId @s @m @i
+  let
+    object = Object oId callback
+    someObject = SomeObject object
+  State.modify \st -> st { objects = HM.insert oId someObject st.objects}
+  pure (object, NewId oId)
+  where
+    allocateObjectId :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => ProtocolAction s m GenericObjectId
+    allocateObjectId = do
+      st <- State.get
+      let
+        id = st.nextId
+        nextId' = id + 1
+
+      when (nextId' == maximumId @s) $ throwM MaximumIdReached
+      State.put $ st {nextId = nextId'}
+      pure id
+
+
 -- | Sends a message without checking any ids or creating proxy objects objects.
 sendMessage :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m ()
 sendMessage object message = protocolStep $ sendMessageInternal object message
@@ -445,7 +499,7 @@ getMessageAction object@(Object _ objectHandler) opcode = do
   message <- getDown object opcode
   pure $ handleMessage objectHandler object message
 
-type RawMessage = (ObjectId, Opcode, BSL.ByteString)
+type RawMessage = (GenericObjectId, Opcode, BSL.ByteString)
 
 receiveRawMessage :: forall s m. MonadCatch m => ProtocolAction s m (Maybe RawMessage)
 receiveRawMessage = do
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index 8be2a10..cb94e96 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -89,9 +89,10 @@ interfaceDecs interface = do
     iName = interfaceN interface
     iT = interfaceT interface
     instanceDecs = [
-      valD (varP 'interfaceName) (normalB (stringE interface.name)) [],
       tySynInstD (tySynEqn Nothing (appT (conT ''Request) iT) rT),
-      tySynInstD (tySynEqn Nothing (appT (conT ''Event) iT) eT)
+      tySynInstD (tySynEqn Nothing (appT (conT ''Event) iT) eT),
+      tySynInstD (tySynEqn Nothing (appT (conT ''InterfaceName) iT) (litT (strTyLit interface.name))),
+      valD (varP 'interfaceName) (normalB (stringE interface.name)) []
       ]
     rT :: Q Type
     rT = if length interface.requests > 0 then conT rTypeName else [t|Void|]
@@ -149,6 +150,15 @@ interfaceDecs interface = do
         applyArgTypes :: Q Type -> Q Type
         applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)
 
+interfaceSideInstanceDs :: InterfaceSpec -> Q [Dec]
+interfaceSideInstanceDs interface = execWriterT do
+  tellQs [d|instance IsInterfaceSide 'Client $iT|]
+  tellQs [d|instance IsInterfaceSide 'Server $iT|]
+  --tellQs [d|instance forall m a. IsInterfaceHandler 'Client m $iT a where {handleMessage = undefined}|]
+  --tellQs [d|instance forall m a. IsInterfaceHandler 'Server m $iT a where {handleMessage = undefined}|]
+  where
+    iT = interfaceT interface
+
 
 interfaceN :: InterfaceSpec -> Name
 interfaceN interface = mkName $ "I_" <> interface.name
@@ -272,12 +282,14 @@ argumentSpecType :: ArgumentSpec -> Q Type
 argumentSpecType argSpec = promoteArgumentSpecType argSpec.argType
 
 promoteArgumentSpecType :: ArgumentType -> Q Type
+promoteArgumentSpecType (ObjectArgument iName) = [t|ObjectId $(litT $ strTyLit iName)|]
+promoteArgumentSpecType (NewIdArgument iName) = [t|NewId $(litT $ strTyLit iName)|]
 promoteArgumentSpecType arg = do
   argExp <- (TH.lift arg)
-  ConT <$> matchCon argExp
+  matchCon argExp
   where
-    matchCon :: Exp -> Q Name
-    matchCon (ConE name) = pure name
+    matchCon :: Exp -> Q Type
+    matchCon (ConE name) = pure $ ConT name
     matchCon (AppE x _) = matchCon x
     matchCon _ = fail "Can only promote ConE expression"
 
@@ -310,8 +322,8 @@ parseInterface :: MonadFail m => Element -> m InterfaceSpec
 parseInterface element = do
   name <- getAttr "name" element
   version <- read <$> getAttr "version" element
-  requests <- mapM parseRequest $ zip [0..] $ findChildren (qname "request") element
-  events <- mapM parseEvent $ zip [0..] $ findChildren (qname "event") element
+  requests <- mapM (parseRequest name) $ zip [0..] $ findChildren (qname "request") element
+  events <- mapM (parseEvent name) $ zip [0..] $ findChildren (qname "event") element
   pure InterfaceSpec {
     name,
     version,
@@ -319,17 +331,21 @@ parseInterface element = do
     events
   }
 
-parseRequest :: MonadFail m => (Opcode, Element) -> m RequestSpec
-parseRequest x = RequestSpec <$> parseMessage x
+parseRequest :: MonadFail m => String -> (Opcode, Element) -> m RequestSpec
+parseRequest x y = RequestSpec <$> parseMessage x y
 
-parseEvent :: MonadFail m => (Opcode, Element) -> m EventSpec
-parseEvent x = EventSpec <$> parseMessage x
+parseEvent :: MonadFail m => String -> (Opcode, Element) -> m EventSpec
+parseEvent x y = EventSpec <$> parseMessage x y
 
-parseMessage :: MonadFail m => (Opcode, Element) -> m MessageSpec
-parseMessage (opcode, element) = do
+parseMessage :: MonadFail m => String -> (Opcode, Element) -> m MessageSpec
+parseMessage interfaceName (opcode, element) = do
   name <- getAttr "name" element
   since <- read <<$>> peekAttr "since" element
   arguments <- mapM parseArgument $ zip [0..] $ findChildren (qname "arg") element
+  forM_ arguments \arg ->
+    when
+      do arg.argType == GenericNewIdArgument && interfaceName /= "wl_registry"
+      do fail $ "Invalid GenericNewIdArgument encountered on " <> interfaceName <> "." <> name <> " (only valid on wl_registry)"
   pure MessageSpec  {
     name,
     since,
@@ -357,9 +373,9 @@ parseArgument (index, element) = do
     parseArgumentType "string" Nothing = pure StringArgument
     parseArgumentType "array" Nothing = pure ArrayArgument
     parseArgumentType "object" (Just interface) = pure (ObjectArgument interface)
-    parseArgumentType "object" Nothing = pure UnknownObjectArgument
+    parseArgumentType "object" Nothing = pure GenericObjectArgument
     parseArgumentType "new_id" (Just interface) = pure (NewIdArgument interface)
-    parseArgumentType "new_id" Nothing = pure UnknownNewIdArgument
+    parseArgumentType "new_id" Nothing = pure GenericNewIdArgument
     parseArgumentType "fd" Nothing = pure FdArgument
     parseArgumentType x Nothing = fail $ "Unknown argument type \"" <> x <> "\" encountered"
     parseArgumentType x _ = fail $ "Argument type \"" <> x <> "\" should not have \"interface\" attribute"
-- 
GitLab