From a5d7adc38ca793aaa2e5f7cc1883bb991ff16bfc Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 7 Dec 2021 22:26:07 +0100
Subject: [PATCH] Move message call proxies to GetField instances

---
 src/Quasar/Wayland/Protocol/Core.hs |  11 +--
 src/Quasar/Wayland/Protocol/TH.hs   | 114 +++++++++++++++-------------
 2 files changed, 65 insertions(+), 60 deletions(-)

diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index e4cff7c..5f76303 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -201,7 +201,6 @@ interfaceName :: forall i. IsInterface i => String
 interfaceName = symbolVal @(InterfaceName i) Proxy
 
 class IsSide (s :: Side) where
-  type Up s i
   type Down s i
   type WireUp s i
   type WireDown s i
@@ -209,7 +208,6 @@ class IsSide (s :: Side) where
   maximumId :: Word32
 
 instance IsSide 'Client where
-  type Up 'Client i = Requests 'Client i
   type Down 'Client i = Events 'Client i
   type WireUp 'Client i = WireRequest i
   type WireDown 'Client i = WireEvent i
@@ -218,7 +216,6 @@ instance IsSide 'Client where
   maximumId = 0xfeffffff
 
 instance IsSide 'Server where
-  type Up 'Server i = Events 'Server i
   type Down 'Server i = Requests 'Server i
   type WireUp 'Server i = WireEvent i
   type WireDown 'Server i = WireRequest i
@@ -233,7 +230,6 @@ class (
     IsMessage (WireDown s i)
   )
   => IsInterfaceSide (s :: Side) i where
-  createProxy :: Object s i -> Up s i
   handleMessage :: Object s i -> WireDown s i -> STM ()
 
 
@@ -255,7 +251,6 @@ data Side = Client | Server
 data Object s i = IsInterfaceSide s i => Object {
   objectProtocol :: (ProtocolHandle s),
   objectObjectId :: GenericObjectId,
-  objectUp :: (Up s i),
   objectDown :: (Down s i),
   objectWireCallback :: (WireCallback s i)
 }
@@ -472,7 +467,7 @@ initializeProtocol wlDisplayWireCallback initializationAction = do
   }
   writeTVar stateVar (Right state)
 
-  let wlDisplay = Object protocol wlDisplayId (createProxy wlDisplay) undefined wlDisplayWireCallback
+  let wlDisplay = Object protocol wlDisplayId undefined wlDisplayWireCallback
   modifyTVar' objectsVar (HM.insert wlDisplayId (SomeObject wlDisplay))
 
   result <- runReaderT (initializationAction wlDisplay) state
@@ -566,7 +561,7 @@ newObjectFromId (NewId oId) callback = do
   protocol <- askProtocol
   let
     genericObjectId = toGenericObjectId oId
-    object = Object protocol genericObjectId (createProxy object) undefined callback
+    object = Object protocol genericObjectId undefined callback
     someObject = SomeObject object
   modifyProtocolVar (.objectsVar) (HM.insert genericObjectId someObject)
   pure object
@@ -633,7 +628,7 @@ getMessageAction
   => Object s i
   -> Opcode
   -> Get (ProtocolM s ())
-getMessageAction object@(Object _ _ _ _ objectHandler) opcode = do
+getMessageAction object@(Object _ _ _ objectHandler) opcode = do
   verifyMessage <- getWireDown object opcode
   pure do
     message <- verifyMessage
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index a00f285..77ab7d3 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -7,7 +7,7 @@ import Control.Monad.Writer
 import Data.ByteString qualified as BS
 import Data.List (intersperse, singleton)
 import Data.Void (absurd)
-import GHC.Records (getField)
+import GHC.Records
 import Language.Haskell.TH
 import Language.Haskell.TH.Syntax (BangType, VarBangType, addDependentFile)
 import Prelude qualified
@@ -78,10 +78,14 @@ isNewId GenericNewIdArgument = True
 isNewId _ = False
 
 
-toDoc :: Maybe DescriptionSpec -> Maybe String
-toDoc (Just DescriptionSpec{content = Just x}) = Just x
-toDoc (Just DescriptionSpec{summary = Just x}) = Just x
-toDoc _ = Nothing
+toWlDoc :: Maybe DescriptionSpec -> Maybe String
+toWlDoc (Just DescriptionSpec{content = Just x}) = Just x
+toWlDoc (Just DescriptionSpec{summary = Just x}) = Just x
+toWlDoc _ = Nothing
+
+withWlDoc :: Maybe DescriptionSpec -> Q Dec -> Q Dec
+withWlDoc (toWlDoc -> Just doc) = withDecDoc doc
+withWlDoc _ = id
 
 
 generateWaylandProcol :: FilePath -> Q [Dec]
@@ -105,34 +109,38 @@ interfaceDecs interface = do
   public <- execWriterT do
     -- Main interface type
     let iCtorDec = (normalC iName [], Nothing, [])
-    tellQ $ dataD_doc (pure []) iName [] Nothing [iCtorDec] [] (toDoc interface.description)
+    tellQ $ dataD_doc (pure []) iName [] Nothing [iCtorDec] [] (toWlDoc interface.description)
     -- IsInterface instance
     tellQ $ instanceD (pure []) [t|IsInterface $iT|] [
       tySynInstD (tySynEqn Nothing [t|$(conT ''Requests) $sT $iT|] (orUnit (requestsT interface sT))),
       tySynInstD (tySynEqn Nothing [t|$(conT ''Events) $sT $iT|] (orUnit (eventsT interface sT))),
       tySynInstD (tySynEqn Nothing (appT (conT ''WireRequest) iT) wireRequestT),
-      tySynInstD (tySynEqn Nothing (appT (conT ''WireEvent) iT) eT),
+      tySynInstD (tySynEqn Nothing (appT (conT ''WireEvent) iT) wireEventT),
       tySynInstD (tySynEqn Nothing (appT (conT ''InterfaceName) iT) (litT (strTyLit interface.name)))
       ]
     -- | IsInterfaceSide instance
     tellQs interfaceSideInstanceDs
 
-    -- | Requests record
     when (length interface.requests > 0) do
-      tellQ requestRecordD
+      -- | Requests record
+      tellQ requestCallbackRecordD
+      -- | Request proxies
+      tellQs requestProxyInstanceDecs
 
-    -- | Events record
     when (length interface.events > 0) do
-      tellQ eventRecordD
+      -- | Events record
+      tellQ eventCallbackRecordD
+      -- | Event proxies
+      tellQs eventProxyInstanceDecs
 
   internals <- execWriterT do
     -- | Request wire type
     when (length interface.requests > 0) do
-      tellQs $ messageTypeDecs rTypeName requestContexts
+      tellQs $ messageTypeDecs rTypeName wireRequestContexts
 
     -- | Event wire type
     when (length interface.events > 0) do
-      tellQs $ messageTypeDecs eTypeName eventContexts
+      tellQs $ messageTypeDecs eTypeName wireEventContexts
 
   pure (public, internals)
 
@@ -146,36 +154,42 @@ interfaceDecs interface = do
     rTypeName = mkName $ "WireRequest_" <> interface.name
     rConName :: RequestSpec -> Name
     rConName (RequestSpec request) = mkName $ "WireRequest_" <> interface.name <> "_" <> request.name
-    eT :: Q Type
-    eT = if length interface.events > 0 then conT eTypeName else [t|Void|]
+    wireEventT :: Q Type
+    wireEventT = if length interface.events > 0 then conT eTypeName else [t|Void|]
     eTypeName :: Name
     eTypeName = mkName $ "WireEvent_" <> interface.name
     eConName :: EventSpec -> Name
     eConName (EventSpec event) = mkName $ "WireEvent_" <> interface.name <> "_" <> event.name
-    requestContext :: RequestSpec -> MessageContext
-    requestContext req@(RequestSpec msgSpec) = MessageContext {
+    wireRequestContext :: RequestSpec -> MessageContext
+    wireRequestContext req@(RequestSpec msgSpec) = MessageContext {
       msgInterfaceT = iT,
       msgT = wireRequestT,
       msgConName = rConName req,
       msgInterfaceSpec = interface,
       msgSpec = msgSpec
     }
-    requestContexts = requestContext <$> interface.requests
-    eventContext :: EventSpec -> MessageContext
-    eventContext ev@(EventSpec msgSpec) = MessageContext {
+    wireRequestContexts = wireRequestContext <$> interface.requests
+    wireEventContext :: EventSpec -> MessageContext
+    wireEventContext ev@(EventSpec msgSpec) = MessageContext {
       msgInterfaceT = iT,
-      msgT = eT,
+      msgT = wireEventT,
       msgConName = eConName ev,
       msgInterfaceSpec = interface,
       msgSpec = msgSpec
     }
-    eventContexts = eventContext <$> interface.events
+    wireEventContexts = wireEventContext <$> interface.events
+
+    requestCallbackRecordD :: Q Dec
+    requestCallbackRecordD = messageRecordD (requestsName interface) wireRequestContexts
+
+    requestProxyInstanceDecs :: Q [Dec]
+    requestProxyInstanceDecs = messageProxyInstanceDecs [t|'Client|] wireRequestContexts
 
-    requestRecordD :: Q Dec
-    requestRecordD = messageRecordD (requestsName interface) requestContexts
+    eventCallbackRecordD :: Q Dec
+    eventCallbackRecordD = messageRecordD (eventsName interface) wireEventContexts
 
-    eventRecordD :: Q Dec
-    eventRecordD = messageRecordD (eventsName interface) eventContexts
+    eventProxyInstanceDecs :: Q [Dec]
+    eventProxyInstanceDecs = messageProxyInstanceDecs [t|'Server|] wireEventContexts
 
     objectName = mkName "object"
     objectP = varP objectName
@@ -183,33 +197,12 @@ interfaceDecs interface = do
 
     interfaceSideInstanceDs :: Q [Dec]
     interfaceSideInstanceDs = execWriterT do
-      tellQ $ instanceD (pure []) ([t|IsInterfaceSide 'Client $iT|]) [createProxyD Client, handleMessageD Client]
-      tellQ $ instanceD (pure []) ([t|IsInterfaceSide 'Server $iT|]) [createProxyD Server, handleMessageD Server]
-
-    createProxyD :: Side -> Q Dec
-    createProxyD Client = funD 'createProxy [clause [objectP] (normalB requestsProxyE) (sendMessageProxy <$> requestContexts)]
-    createProxyD Server = funD 'createProxy [clause [objectP] (normalB eventsProxyE) (sendMessageProxy <$> eventContexts)]
-    requestsProxyE :: Q Exp
-    requestsProxyE
-      | length interface.requests > 0 = recConE (requestsName interface) (sendMessageProxyField <$> requestContexts)
-      | otherwise = [|()|]
-    eventsProxyE :: Q Exp
-    eventsProxyE
-      | length interface.events > 0 = recConE (eventsName interface) (sendMessageProxyField <$> eventContexts)
-      | otherwise = [|()|]
-
-    sendMessageProxyField :: MessageContext -> Q (Name, Exp)
-    sendMessageProxyField msg = (messageFieldName msg, ) <$> varE (sendMessageFunctionName msg)
-
-    sendMessageFunctionName :: MessageContext -> Name
-    sendMessageFunctionName msg = mkName $ "send_" <> messageFieldNameString msg
-
-    sendMessageProxy :: MessageContext -> Q Dec
-    sendMessageProxy msg = funD (sendMessageFunctionName msg) [clause (msgArgPats msg) (normalB [|objectSendMessage object $(msgE msg)|]) []]
+      tellQ $ instanceD (pure []) ([t|IsInterfaceSide 'Client $iT|]) [handleMessageD Client]
+      tellQ $ instanceD (pure []) ([t|IsInterfaceSide 'Server $iT|]) [handleMessageD Server]
 
     handleMessageD :: Side -> Q Dec
-    handleMessageD Client = funD 'handleMessage (handleMessageClauses eventContexts)
-    handleMessageD Server = funD 'handleMessage (handleMessageClauses requestContexts)
+    handleMessageD Client = funD 'handleMessage (handleMessageClauses wireEventContexts)
+    handleMessageD Server = funD 'handleMessage (handleMessageClauses wireRequestContexts)
 
     handleMessageClauses :: [MessageContext] -> [Q Clause]
     handleMessageClauses [] = [clause [wildP] (normalB [|absurd|]) []]
@@ -225,6 +218,24 @@ interfaceDecs interface = do
         bodyE :: Q Exp
         bodyE = applyMsgArgs msg fieldE
 
+messageProxyInstanceDecs :: Q Type -> [MessageContext] -> Q [Dec]
+messageProxyInstanceDecs sideT messageContexts = mapM messageProxyInstanceD messageContexts
+  where
+    messageProxyInstanceD :: MessageContext -> Q Dec
+    messageProxyInstanceD msg = instanceD (pure []) instanceT [
+      funD 'getField [clause ([varP objectName] <> msgArgPats msg) (normalB [|objectSendMessage object $(msgE msg)|]) []]
+      ]
+      where
+        objectName = mkName "object"
+        instanceT :: Q Type
+        instanceT = [t|HasField $(litT (strTyLit msg.msgSpec.name)) $objectT $proxyT|]
+        objectT :: Q Type
+        objectT = [t|Object $sideT $(msg.msgInterfaceT)|]
+        proxyT :: Q Type
+        proxyT = [t|$(applyArgTypes [t|STM ()|])|]
+        applyArgTypes :: Q Type -> Q Type
+        applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)
+
 
 messageFieldName :: MessageContext -> Name
 messageFieldName msg = mkName $ messageFieldNameString msg
@@ -243,7 +254,6 @@ messageRecordD name messageContexts = dataD (cxt []) name [plainTV sideTVarName]
         applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)
 
 
-
 sideTVarName :: Name
 sideTVarName = mkName "s"
 sideTVar :: Q Type
-- 
GitLab