From 70d4f07b5412403fe2f06a4182012b54bb40eec9 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 21 Sep 2021 16:16:52 +0200
Subject: [PATCH] Run message de-/encoder in ProtocolM monad

Experimental way to enable early object id checks, might be removed once
the higher-level api generater can provide the checks.
---
 src/Quasar/Wayland/Protocol/Core.hs    | 115 ++++++++++++++-----------
 src/Quasar/Wayland/Protocol/Display.hs |   7 +-
 src/Quasar/Wayland/Protocol/TH.hs      |  24 ++++--
 src/Quasar/Wayland/Registry.hs         |  10 +--
 4 files changed, 88 insertions(+), 68 deletions(-)

diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index abd711c..cb0dd25 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -59,6 +59,7 @@ import Control.Concurrent.STM
 import Control.Monad (replicateM_)
 import Control.Monad.Catch
 import Control.Monad.Reader (ReaderT, runReaderT, ask, lift)
+import Data.Bifunctor qualified as Bifunctor
 import Data.Binary
 import Data.Binary.Get
 import Data.Binary.Put
@@ -104,6 +105,7 @@ instance Show Fixed where
 --
 -- Instances and functions in this library assume UTF-8, but the original data is also available by deconstructing.
 newtype WlString = WlString BS.ByteString
+  deriving newtype (Eq, Hashable)
 
 instance Show WlString where
   show = show . toString
@@ -139,62 +141,62 @@ isNewId _ = False
 
 class (Eq (Argument a), Show (Argument a)) => WireFormat a where
   type Argument a
-  putArgument :: Argument a -> PutM ()
-  getArgument :: Get (Argument a)
+  putArgument :: Argument a -> ProtocolM s (Put, Int)
+  getArgument :: Get (ProtocolM s (Argument a))
   showArgument :: Argument a -> String
 
 instance WireFormat 'IntArgument where
   type Argument 'IntArgument = Int32
-  putArgument = putInt32host
-  getArgument = getInt32host
+  putArgument x = pure (putInt32host x, 4)
+  getArgument = pure <$> getInt32host
   showArgument = show
 
 instance WireFormat 'UIntArgument where
   type Argument 'UIntArgument = Word32
-  putArgument = putWord32host
-  getArgument = getWord32host
+  putArgument x = pure (putWord32host x, 4)
+  getArgument = pure <$> getWord32host
   showArgument = show
 
 instance WireFormat 'FixedArgument where
   type Argument 'FixedArgument = Fixed
-  putArgument (Fixed repr) = putWord32host repr
-  getArgument = Fixed <$> getWord32host
+  putArgument (Fixed repr) = pure (putWord32host repr, 4)
+  getArgument = pure . Fixed <$> getWord32host
   showArgument = show
 
 instance WireFormat 'StringArgument where
-  type Argument 'StringArgument = BS.ByteString
-  putArgument = putWaylandBlob
-  getArgument = getWaylandBlob
+  type Argument 'StringArgument = WlString
+  putArgument (WlString x) = pure $ putWaylandBlob x
+  getArgument = pure . WlString <$> getWaylandBlob
   showArgument = show
 
 instance WireFormat 'ArrayArgument where
   type Argument 'ArrayArgument = BS.ByteString
-  putArgument = putWaylandBlob
-  getArgument = getWaylandBlob
+  putArgument x = pure $ putWaylandBlob x
+  getArgument = pure <$> getWaylandBlob
   showArgument array = "[array " <> show (BS.length array) <> "B]"
 
 instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where
   type Argument (ObjectId j) = ObjectId j
-  putArgument (ObjectId oId) = putWord32host oId
-  getArgument = ObjectId <$> getWord32host
+  putArgument (ObjectId oId) = pure (putWord32host oId, 4)
+  getArgument = pure . ObjectId <$> getWord32host
   showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId
 
 instance WireFormat 'GenericObjectArgument where
   type Argument 'GenericObjectArgument = GenericObjectId
-  putArgument = putWord32host
-  getArgument = getWord32host
+  putArgument x = pure (putWord32host x, 4)
+  getArgument = pure <$> getWord32host
   showArgument oId = "[unknown]@" <> show oId
 
 instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where
   type Argument (NewId j) = NewId j
-  putArgument (NewId newId) = putWord32host newId
-  getArgument = NewId <$> getWord32host
+  putArgument (NewId newId) = pure (putWord32host newId, 4)
+  getArgument = pure . NewId <$> getWord32host
   showArgument (NewId newId) = "new " <> symbolVal @j Proxy <> "@" <> show newId
 
 instance WireFormat 'GenericNewIdArgument where
   type Argument 'GenericNewIdArgument = GenericNewId
-  putArgument (GenericNewId newId) = putWord32host newId
-  getArgument = GenericNewId <$> getWord32host
+  putArgument (GenericNewId newId) = pure (putWord32host newId, 4)
+  getArgument = pure . GenericNewId <$> getWord32host
   showArgument newId = "new [unknown]@" <> show newId
 
 instance WireFormat 'FdArgument where
@@ -245,10 +247,10 @@ class (
   => IsInterfaceSide (s :: Side) i
 
 
-getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (Down s i)
+getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (ProtocolM s (Down s i))
 getDown = getMessage @(Down s i)
 
-putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> PutM Opcode
+putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s (Opcode, [(Put, Int)])
 putUp _ = putMessage @(Up s i)
 
 
@@ -307,8 +309,8 @@ instance IsObjectSide (SomeObject s) where
 
 class (Eq a, Show a) => IsMessage a where
   opcodeName :: Opcode -> Maybe String
-  getMessage :: IsInterface i => Object s i -> Opcode -> Get a
-  putMessage :: a -> PutM Opcode
+  getMessage :: IsInterface i => Object s i -> Opcode -> Get (ProtocolM s a)
+  putMessage :: a -> ProtocolM s (Opcode, [(Put, Int)])
 
 instance IsMessage Void where
   opcodeName _ = Nothing
@@ -413,6 +415,11 @@ modifyProtocolVar fn x = do
   state <- ask
   lift $ modifyTVar (fn state) x
 
+modifyProtocolVar' :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s ()
+modifyProtocolVar' fn x = do
+  state <- ask
+  lift $ modifyTVar' (fn state) x
+
 stateProtocolVar :: (ProtocolState s -> TVar a) -> (a -> (r, a)) -> ProtocolM s r
 stateProtocolVar fn x = do
   state <- ask
@@ -484,7 +491,7 @@ runProtocolM protocol action = either throwM (runReaderT action) =<< readTVar pr
 feedInput :: (IsSide s, MonadIO m, MonadThrow m) => ProtocolHandle s -> ByteString -> m ()
 feedInput protocol bytes = runProtocolTransaction protocol do
   -- Exposing MonadIO instead of STM to the outside and using `runProtocolTransaction` here enforces correct exception handling.
-  modifyProtocolVar (.bytesReceivedVar) (+ fromIntegral (BS.length bytes))
+  modifyProtocolVar' (.bytesReceivedVar) (+ fromIntegral (BS.length bytes))
   modifyProtocolVar (.inboxDecoderVar) (`pushChunk` bytes)
   receiveMessages
 
@@ -498,7 +505,7 @@ takeOutbox protocol = runProtocolTransaction protocol do
   mOutboxData <- stateProtocolVar (.outboxVar) (\mOutboxData -> (mOutboxData, Nothing))
   outboxData <- maybe (lift retry) pure mOutboxData
   let sendData = runPut outboxData
-  modifyProtocolVar (.bytesSentVar) (+ BSL.length sendData)
+  modifyProtocolVar' (.bytesSentVar) (+ BSL.length sendData)
   pure sendData
 
 
@@ -540,25 +547,28 @@ newObjectFromId (NewId oId) callback = do
 -- | Sends a message without checking any ids or creating proxy objects objects. (TODO)
 sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s ()
 sendMessage object message = do
+  (opcode, pairs) <- putUp object message
+  let (putBodyParts, partLengths) = unzip pairs
+  let putBody = mconcat putBodyParts
+  let bodyLength = foldr (+) 0 partLengths
+  let body = runPut putBody
   traceM $ "-> " <> showObjectMessage object message
-  sendRawMessage messageWithHeader
+  sendRawMessage $ messageWithHeader opcode body
   where
-    body :: BSL.ByteString
-    opcode :: Opcode
-    (opcode, body) = runPutM $ putUp object message
-    messageWithHeader :: Put
-    messageWithHeader = do
+    messageWithHeader :: Opcode -> BSL.ByteString -> Put
+    messageWithHeader opcode body = do
       putWord32host $ objectId object
       putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
       putLazyByteString body
-    msgSize :: Word16
-    msgSize =
-      if msgSizeInteger <= fromIntegral (maxBound :: Word16)
-        then fromIntegral msgSizeInteger
-        else error "Message too large"
-    -- TODO: body length should be returned from `putMessage`, instead of realizing it to a ByteString here
-    msgSizeInteger :: Integer
-    msgSizeInteger = 8 + fromIntegral (BSL.length body)
+      where
+        msgSize :: Word16
+        msgSize =
+          if msgSizeInteger <= fromIntegral (maxBound :: Word16)
+            then fromIntegral msgSizeInteger
+            else error "Message too large"
+        -- TODO: body length should be returned from `putMessage`, instead of realizing it to a ByteString here
+        msgSizeInteger :: Integer
+        msgSizeInteger = 8 + fromIntegral (BSL.length body)
 
 
 receiveMessages :: IsSide s => ProtocolM s ()
@@ -591,8 +601,8 @@ getMessageAction
   -> Opcode
   -> Get (ProtocolM s ())
 getMessageAction object@(Object _ objectHandler) opcode = do
-  message <- getDown object opcode
-  pure $ handleMessage objectHandler object message
+  verifyMessage <- getDown object opcode
+  pure $ handleMessage objectHandler object =<< verifyMessage
 
 type RawMessage = (GenericObjectId, Opcode, BSL.ByteString)
 
@@ -627,13 +637,18 @@ getWaylandBlob = do
   skipPadding
   pure string
 
-putWaylandBlob :: BS.ByteString -> Put
-putWaylandBlob blob = do
-  let size = BS.length blob
-  putWord32host (fromIntegral (size + 1))
-  putByteString blob
-  putWord8 0
-  replicateM_ (padding size) (putWord8 0)
+putWaylandBlob :: BS.ByteString -> (Put, Int)
+putWaylandBlob blob = (putBlob, 4 + len + pad)
+  where
+    -- Total data length including null byte
+    len = BS.length blob + 1
+    -- Padding length
+    pad = padding len
+    putBlob = do
+      putWord32host (fromIntegral (len + 1))
+      putByteString blob
+      putWord8 0
+      replicateM_ pad (putWord8 0)
 
 
 skipPadding :: Get ()
diff --git a/src/Quasar/Wayland/Protocol/Display.hs b/src/Quasar/Wayland/Protocol/Display.hs
index 3b58452..85a3011 100644
--- a/src/Quasar/Wayland/Protocol/Display.hs
+++ b/src/Quasar/Wayland/Protocol/Display.hs
@@ -3,7 +3,6 @@ module Quasar.Wayland.Protocol.Display (
 ) where
 
 import Control.Monad.Catch
-import Data.ByteString.UTF8 qualified as BS
 import Data.HashMap.Strict qualified as HM
 import Quasar.Prelude
 import Quasar.Wayland.Protocol.Core
@@ -19,7 +18,7 @@ clientWlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Cli
 clientWlDisplayCallback = internalFnCallback handler
   where
     -- | wl_display is specified to never change, so manually specifying the callback is safe
-    handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolM 'Client ()
+    handler :: Object 'Client I_wl_display -> WireEvent_wl_display -> ProtocolM 'Client ()
     -- TODO parse oId
-    handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (BS.toString message)
-    handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete
+    handler _ (WireEvent_wl_display_error oId code message) = throwM $ ServerError code (toString message)
+    handler _ (WireEvent_wl_display_delete_id deletedId) = pure () -- TODO confirm delete
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index 9f27a5f..85f60e9 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -97,15 +97,15 @@ interfaceDecs interface = do
     rT :: Q Type
     rT = if length interface.requests > 0 then conT rTypeName else [t|Void|]
     rTypeName :: Name
-    rTypeName = mkName $ "R_" <> interface.name
+    rTypeName = mkName $ "WireRequest_" <> interface.name
     rConName :: RequestSpec -> Name
-    rConName (RequestSpec request) = mkName $ "R_" <> interface.name <> "_" <> request.name
+    rConName (RequestSpec request) = mkName $ "WireRequest_" <> interface.name <> "_" <> request.name
     eT :: Q Type
     eT = if length interface.events > 0 then conT eTypeName else [t|Void|]
     eTypeName :: Name
-    eTypeName = mkName $ "E_" <> interface.name
+    eTypeName = mkName $ "WireEvent_" <> interface.name
     eConName :: EventSpec -> Name
-    eConName (EventSpec event) = mkName $ "E_" <> interface.name <> "_" <> event.name
+    eConName (EventSpec event) = mkName $ "WireEvent_" <> interface.name <> "_" <> event.name
     requestContext :: RequestSpec -> MessageContext
     requestContext req@(RequestSpec msgSpec) = MessageContext {
       msgInterfaceT = iT,
@@ -246,7 +246,7 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD,
     getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
       where
         getMessageE :: Q Exp
-        getMessageE = applyA (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments)
+        getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments)
     getMessageInvalidOpcodeClause :: Q Clause
     getMessageInvalidOpcodeClause = do
       let object = mkName "object"
@@ -258,10 +258,10 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD,
     putMessageClauseD msg = clause [msgConP msg] (normalB (putMessageE msg.msgSpec.arguments)) []
       where
         putMessageE :: [ArgumentSpec] -> Q Exp
-        putMessageE [] = opcodeE
-        putMessageE args = doE (((\arg -> noBindS [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args) <> [noBindS opcodeE])
-        opcodeE :: Q Exp
-        opcodeE = [|pure $(litE $ integerL $ fromIntegral msg.msgSpec.opcode)|]
+        putMessageE args = [|($(litE $ integerL $ fromIntegral msg.msgSpec.opcode), ) <$> $(putMessageBodyE args)|]
+        putMessageBodyE :: [ArgumentSpec] -> Q Exp
+        putMessageBodyE [] = [|pure []|]
+        putMessageBodyE args = [|sequence $(listE ((\arg -> [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args))|]
 
 
 derivingEq :: Q DerivClause
@@ -309,6 +309,12 @@ applyM con [] = con
 applyM con args = [|join $(applyA con args)|]
 
 
+-- | (a -> b -> c -> d) -> [f (g a), f (g b), f (g c)] -> f (g d)
+applyALifted :: Q Exp -> [Q Exp] -> Q Exp
+applyALifted con [] = [|pure $ pure $con|]
+applyALifted con (monadicE:monadicEs) = foldl (\x y -> [|$x <<*>> $y|]) [|$con <<$>> $monadicE|] monadicEs
+
+
 -- * XML parser
 
 parseProtocol :: MonadFail m => BS.ByteString -> m ProtocolSpec
diff --git a/src/Quasar/Wayland/Registry.hs b/src/Quasar/Wayland/Registry.hs
index 0bb0ef6..0f50133 100644
--- a/src/Quasar/Wayland/Registry.hs
+++ b/src/Quasar/Wayland/Registry.hs
@@ -15,7 +15,7 @@ import Quasar.Wayland.Protocol.Generated
 
 data ClientRegistry = ClientRegistry {
   wlRegistry :: Object 'Client I_wl_registry,
-  globalsVar :: TVar (HM.HashMap Word32 (BS.ByteString, Word32))
+  globalsVar :: TVar (HM.HashMap Word32 (WlString, Word32))
 }
 
 createClientRegistry :: Object 'Client I_wl_display -> ProtocolM 'Client ClientRegistry
@@ -23,7 +23,7 @@ createClientRegistry wlDisplay = mfix \clientRegistry -> do
   globalsVar <- lift $ newTVar HM.empty
 
   (wlRegistry, newId) <- newObject @'Client @I_wl_registry (traceCallback (callback clientRegistry))
-  sendMessage wlDisplay $ R_wl_display_get_registry newId
+  sendMessage wlDisplay $ WireRequest_wl_display_get_registry newId
 
   pure ClientRegistry {
     wlRegistry,
@@ -34,10 +34,10 @@ createClientRegistry wlDisplay = mfix \clientRegistry -> do
     callback clientRegistry = internalFnCallback handler
       where
         -- | wl_registry is specified to never change, so manually specifying the callback is safe
-        handler :: Object 'Client I_wl_registry -> E_wl_registry -> ProtocolM 'Client ()
-        handler _ (E_wl_registry_global name interface version) = do
+        handler :: Object 'Client I_wl_registry -> WireEvent_wl_registry -> ProtocolM 'Client ()
+        handler _ (WireEvent_wl_registry_global name interface version) = do
           lift $ modifyTVar clientRegistry.globalsVar (HM.insert name (interface, version))
-        handler _ (E_wl_registry_global_remove name) = do
+        handler _ (WireEvent_wl_registry_global_remove name) = do
           result <- lift $ stateTVar clientRegistry.globalsVar (swap . lookupDelete name)
           case result of
             Nothing -> traceM $ "Invalid global removed by server: " <> show name
-- 
GitLab