From 47cca98a8eba2f119fef1d992c3dded6302b6882 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 21 Sep 2021 17:53:18 +0200
Subject: [PATCH] Apply argument type directly (removing ArgumentType type
 family)

---
 src/Quasar/Wayland/Protocol/Core.hs | 121 +++++++++++-----------------
 src/Quasar/Wayland/Protocol/TH.hs   |  96 +++++++++++++---------
 2 files changed, 108 insertions(+), 109 deletions(-)

diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index cb0dd25..21855a4 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -4,8 +4,8 @@ module Quasar.Wayland.Protocol.Core (
   ObjectId,
   GenericObjectId,
   NewId,
+  GenericNewId,
   Opcode,
-  ArgumentType(..),
   Fixed(..),
   WlString(..),
   toString,
@@ -46,9 +46,6 @@ module Quasar.Wayland.Protocol.Core (
   MaximumIdReached(..),
   ServerError(..),
 
-  -- * TH utilities
-  isNewId,
-
   -- * Message decoder operations
   WireFormat(..),
   dropRemaining,
@@ -74,23 +71,26 @@ import Data.Proxy
 import Data.String (IsString(..))
 import Data.Void (absurd)
 import GHC.TypeLits
-import Language.Haskell.TH.Syntax (Lift)
 import Quasar.Prelude
 
 
-newtype ObjectId (j :: Symbol) = ObjectId GenericObjectId
-  deriving stock (Eq, Show)
+newtype ObjectId (j :: Symbol) = ObjectId Word32
+  deriving newtype (Eq, Show, Hashable)
+
+newtype GenericObjectId = GenericObjectId Word32
+  deriving newtype (Eq, Show, Hashable)
 
-type GenericObjectId = Word32
+toGenericObjectId :: ObjectId j -> GenericObjectId
+toGenericObjectId (ObjectId oId) = GenericObjectId oId
 
 type Opcode = Word16
 
 
-newtype NewId (j :: Symbol) = NewId GenericObjectId
-  deriving stock (Eq, Show)
+newtype NewId (j :: Symbol) = NewId (ObjectId j)
+  deriving newtype (Eq, Show)
 
 newtype GenericNewId = GenericNewId GenericObjectId
-  deriving stock (Eq, Show)
+  deriving newtype (Eq, Show)
 
 
 -- | Signed 24.8 decimal numbers.
@@ -121,86 +121,57 @@ dropRemaining :: Get ()
 dropRemaining = void getRemainingLazyByteString
 
 
-data ArgumentType
-  = IntArgument
-  | UIntArgument
-  | FixedArgument
-  | StringArgument
-  | ArrayArgument
-  | ObjectArgument String
-  | GenericObjectArgument
-  | NewIdArgument String
-  | GenericNewIdArgument
-  | FdArgument
-  deriving stock (Eq, Show, Lift)
-
-isNewId :: ArgumentType -> Bool
-isNewId (NewIdArgument _) = True
-isNewId GenericNewIdArgument = True
-isNewId _ = False
-
-class (Eq (Argument a), Show (Argument a)) => WireFormat a where
-  type Argument a
-  putArgument :: Argument a -> ProtocolM s (Put, Int)
-  getArgument :: Get (ProtocolM s (Argument a))
-  showArgument :: Argument a -> String
-
-instance WireFormat 'IntArgument where
-  type Argument 'IntArgument = Int32
+class (Eq a, Show a) => WireFormat a where
+  putArgument :: a -> ProtocolM s (Put, Int)
+  getArgument :: Get (ProtocolM s a)
+  showArgument :: a -> String
+
+instance WireFormat Int32 where
   putArgument x = pure (putInt32host x, 4)
   getArgument = pure <$> getInt32host
   showArgument = show
 
-instance WireFormat 'UIntArgument where
-  type Argument 'UIntArgument = Word32
+instance WireFormat Word32 where
   putArgument x = pure (putWord32host x, 4)
   getArgument = pure <$> getWord32host
   showArgument = show
 
-instance WireFormat 'FixedArgument where
-  type Argument 'FixedArgument = Fixed
+instance WireFormat Fixed where
   putArgument (Fixed repr) = pure (putWord32host repr, 4)
   getArgument = pure . Fixed <$> getWord32host
   showArgument = show
 
-instance WireFormat 'StringArgument where
-  type Argument 'StringArgument = WlString
+instance WireFormat WlString where
   putArgument (WlString x) = pure $ putWaylandBlob x
   getArgument = pure . WlString <$> getWaylandBlob
   showArgument = show
 
-instance WireFormat 'ArrayArgument where
-  type Argument 'ArrayArgument = BS.ByteString
+instance WireFormat BS.ByteString where
   putArgument x = pure $ putWaylandBlob x
   getArgument = pure <$> getWaylandBlob
   showArgument array = "[array " <> show (BS.length array) <> "B]"
 
 instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where
-  type Argument (ObjectId j) = ObjectId j
   putArgument (ObjectId oId) = pure (putWord32host oId, 4)
   getArgument = pure . ObjectId <$> getWord32host
   showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId
 
-instance WireFormat 'GenericObjectArgument where
-  type Argument 'GenericObjectArgument = GenericObjectId
-  putArgument x = pure (putWord32host x, 4)
-  getArgument = pure <$> getWord32host
+instance WireFormat GenericObjectId where
+  putArgument (GenericObjectId oId) = pure (putWord32host oId, 4)
+  getArgument = pure . GenericObjectId <$> getWord32host
   showArgument oId = "[unknown]@" <> show oId
 
 instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where
-  type Argument (NewId j) = NewId j
-  putArgument (NewId newId) = pure (putWord32host newId, 4)
-  getArgument = pure . NewId <$> getWord32host
+  putArgument (NewId newId) = putArgument newId
+  getArgument = NewId <<$>> getArgument
   showArgument (NewId newId) = "new " <> symbolVal @j Proxy <> "@" <> show newId
 
-instance WireFormat 'GenericNewIdArgument where
-  type Argument 'GenericNewIdArgument = GenericNewId
-  putArgument (GenericNewId newId) = pure (putWord32host newId, 4)
-  getArgument = pure . GenericNewId <$> getWord32host
+instance WireFormat GenericNewId where
+  putArgument (GenericNewId newId) = putArgument newId
+  getArgument = GenericNewId <<$>> getArgument
   showArgument newId = "new [unknown]@" <> show newId
 
-instance WireFormat 'FdArgument where
-  type Argument 'FdArgument = Void
+instance WireFormat Void where
   putArgument = undefined
   getArgument = undefined
   showArgument = undefined
@@ -220,8 +191,8 @@ class (
 class IsSide (s :: Side) where
   type Up s i
   type Down s i
-  initialId :: GenericObjectId
-  maximumId :: GenericObjectId
+  initialId :: Word32
+  maximumId :: Word32
 
 instance IsSide 'Client where
   type Up 'Client i = Request i
@@ -435,7 +406,7 @@ initializeProtocol wlDisplayCallback initializationAction = do
   bytesSentVar <- newTVar 0
   inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage
   outboxVar <- newTVar Nothing
-  objectsVar <- newTVar $ HM.fromList [(1, (SomeObject wlDisplay))]
+  objectsVar <- newTVar $ HM.fromList [(wlDisplayId, (SomeObject wlDisplay))]
   nextIdVar <- newTVar (initialId @s)
   let state = ProtocolState {
     bytesReceivedVar,
@@ -452,8 +423,10 @@ initializeProtocol wlDisplayCallback initializationAction = do
   result <- runReaderT (initializationAction wlDisplay) state
   pure (result, protocol)
   where
+    wlDisplayId :: GenericObjectId
+    wlDisplayId = GenericObjectId 1
     wlDisplay :: Object s wl_display
-    wlDisplay = Object 1 wlDisplayCallback
+    wlDisplay = Object wlDisplayId wlDisplayCallback
 
 -- | Run a protocol action in 'IO'. If an exception occurs, it is stored as a protocol failure and is then
 -- re-thrown.
@@ -516,12 +489,12 @@ newObject
   => Callback s i
   -> ProtocolM s (Object s i, NewId (InterfaceName i))
 newObject callback = do
-  genOId <- allocateObjectId
-  let oId = NewId @(InterfaceName i) genOId
-  object <- newObjectFromId oId callback
-  pure (object, oId)
+  oId <- allocateObjectId
+  let newId = NewId @(InterfaceName i) oId
+  object <- newObjectFromId newId callback
+  pure (object, newId)
   where
-    allocateObjectId :: ProtocolM s GenericObjectId
+    allocateObjectId :: ProtocolM s (ObjectId (InterfaceName i))
     allocateObjectId = do
       id' <- readProtocolVar (.nextIdVar)
 
@@ -529,7 +502,7 @@ newObject callback = do
       when (nextId' == maximumId @s) $ throwM MaximumIdReached
 
       writeProtocolVar (.nextIdVar) nextId'
-      pure id'
+      pure $ ObjectId id'
 
 newObjectFromId
   :: forall s i. IsInterfaceSide s i
@@ -538,9 +511,10 @@ newObjectFromId
   -> ProtocolM s (Object s i)
 newObjectFromId (NewId oId) callback = do
   let
-    object = Object oId callback
+    genericObjectId = toGenericObjectId oId
+    object = Object genericObjectId callback
     someObject = SomeObject object
-  modifyProtocolVar (.objectsVar) (HM.insert oId someObject)
+  modifyProtocolVar (.objectsVar) (HM.insert genericObjectId someObject)
   pure object
 
 
@@ -557,10 +531,11 @@ sendMessage object message = do
   where
     messageWithHeader :: Opcode -> BSL.ByteString -> Put
     messageWithHeader opcode body = do
-      putWord32host $ objectId object
+      putWord32host objectIdWord
       putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
       putLazyByteString body
       where
+        (GenericObjectId objectIdWord) = objectId object
         msgSize :: Word16
         msgSize =
           if msgSizeInteger <= fromIntegral (maxBound :: Word16)
@@ -622,7 +597,7 @@ receiveRawMessage = do
 
 getRawMessage :: Get RawMessage
 getRawMessage = do
-  oId <- getWord32host
+  oId <- GenericObjectId <$> getWord32host
   sizeAndOpcode <- getWord32host
   let
     size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index 85f60e9..c687e5b 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -5,8 +5,10 @@ module Quasar.Wayland.Protocol.TH (
 import Control.Monad.Writer
 import Data.ByteString qualified as BS
 import Language.Haskell.TH
-import Language.Haskell.TH.Syntax (BangType, addDependentFile)
+import Language.Haskell.TH.Syntax (BangType, VarBangType, addDependentFile)
 import Language.Haskell.TH.Syntax qualified as TH
+import Data.ByteString qualified as BS
+import Data.Int (Int32)
 import Data.List (intersperse)
 import Prelude qualified
 import Quasar.Prelude
@@ -48,6 +50,25 @@ data ArgumentSpec = ArgumentSpec {
 }
   deriving stock Show
 
+data ArgumentType
+  = IntArgument
+  | UIntArgument
+  | FixedArgument
+  | StringArgument
+  | ArrayArgument
+  | ObjectArgument String
+  | GenericObjectArgument
+  | NewIdArgument String
+  | GenericNewIdArgument
+  | FdArgument
+  deriving stock (Eq, Show)
+
+isNewId :: ArgumentType -> Bool
+isNewId (NewIdArgument _) = True
+isNewId GenericNewIdArgument = True
+isNewId _ = False
+
+
 
 generateWaylandProcol :: FilePath -> Q [Dec]
 generateWaylandProcol protocolFile = do
@@ -71,8 +92,8 @@ tellQs = tell <=< lift
 interfaceDecs :: InterfaceSpec -> Q ([Dec], [Dec])
 interfaceDecs interface = do
   public <- execWriterT do
-    tellQ requestClassD
-    tellQ eventClassD
+    tellQ requestRecordD
+    tellQ eventRecordD
   internals <- execWriterT do
     tellQ $ dataD (pure []) iName [] Nothing [normalC iName []] [derivingInterfaceClient, derivingInterfaceServer]
     tellQ $ instanceD (pure []) [t|IsInterface $iT|] instanceDecs
@@ -133,29 +154,32 @@ interfaceDecs interface = do
     mName = mkName "m"
     mType :: Q Type
     mType = varT mName
+    sName :: Name
+    sName = mkName "s"
+    sType :: Q Type
+    sType = varT sName
 
-    requestClassD :: Q Dec
-    requestClassD =
-      -- [t|MonadCatch $mType|]
-      classD (cxt []) (requestClassN interface) [plainTV mName, plainTV aName] [] (callSigD <$> requestContexts)
+    requestRecordD :: Q Dec
+    requestRecordD = messageRecordD (requestClassN interface) requestContexts
 
-    eventClassD :: Q Dec
-    eventClassD =
-      -- [t|MonadCatch $mType|]
-      classD (cxt []) (eventClassN interface) [plainTV mName, plainTV aName] [] (callSigD <$> eventContexts)
+    eventRecordD :: Q Dec
+    eventRecordD = messageRecordD (eventClassN interface) requestContexts
 
-    callSigD :: MessageContext -> Q Dec
-    callSigD msg = sigD (mkName (interface.name <> "__" <> msg.msgSpec.name)) [t|$aType -> $(applyArgTypes [t|$mType ()|])|]
+    messageRecordD :: Name -> [MessageContext] -> Q Dec
+    messageRecordD name messageContexts = dataD (cxt []) name [] Nothing [con] []
       where
-        applyArgTypes :: Q Type -> Q Type
-        applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)
+        con = recC name (recField <$> messageContexts)
+        recField :: MessageContext -> Q VarBangType
+        recField msg = varDefaultBangType (mkName msg.msgSpec.name) [t|$(applyArgTypes [t|forall s. ProtocolM s ()|])|]
+          where
+            applyArgTypes :: Q Type -> Q Type
+            applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)
+
 
 interfaceSideInstanceDs :: InterfaceSpec -> Q [Dec]
 interfaceSideInstanceDs interface = execWriterT do
   tellQs [d|instance IsInterfaceSide 'Client $iT|]
   tellQs [d|instance IsInterfaceSide 'Server $iT|]
-  --tellQs [d|instance forall m a. IsInterfaceHandler 'Client m $iT a where {handleMessage = undefined}|]
-  --tellQs [d|instance forall m a. IsInterfaceHandler 'Server m $iT a where {handleMessage = undefined}|]
   where
     iT = interfaceT interface
 
@@ -229,7 +253,7 @@ messageTypeDecs name msgs = execWriterT do
         []
       where
         showArgE :: ArgumentSpec -> [Q Exp]
-        showArgE arg = [stringE (arg.name ++ "="), [|showArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]]
+        showArgE arg = [stringE (arg.name ++ "="), [|showArgument @($(argumentType arg)) $(msgArgE msg arg)|]]
 
 isMessageInstanceD :: Q Type -> [MessageContext] -> Q Dec
 isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, getMessageD, putMessageD]
@@ -246,7 +270,7 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD,
     getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
       where
         getMessageE :: Q Exp
-        getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments)
+        getMessageE = applyALifted (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentType <$> msg.msgSpec.arguments)
     getMessageInvalidOpcodeClause :: Q Clause
     getMessageInvalidOpcodeClause = do
       let object = mkName "object"
@@ -261,7 +285,7 @@ isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD,
         putMessageE args = [|($(litE $ integerL $ fromIntegral msg.msgSpec.opcode), ) <$> $(putMessageBodyE args)|]
         putMessageBodyE :: [ArgumentSpec] -> Q Exp
         putMessageBodyE [] = [|pure []|]
-        putMessageBodyE args = [|sequence $(listE ((\arg -> [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args))|]
+        putMessageBodyE args = [|sequence $(listE ((\arg -> [|putArgument @($(argumentType arg)) $(msgArgE msg arg)|]) <$> args))|]
 
 
 derivingEq :: Q DerivClause
@@ -277,26 +301,26 @@ derivingInterfaceServer :: Q DerivClause
 derivingInterfaceServer = derivClause (Just AnyclassStrategy) [[t|IsInterfaceSide 'Server|]]
 
 argumentType :: ArgumentSpec -> Q Type
-argumentType argSpec = [t|Argument $(promoteArgumentSpecType argSpec.argType)|]
-
-argumentSpecType :: ArgumentSpec -> Q Type
-argumentSpecType argSpec = promoteArgumentSpecType argSpec.argType
-
-promoteArgumentSpecType :: ArgumentType -> Q Type
-promoteArgumentSpecType (ObjectArgument iName) = [t|ObjectId $(litT $ strTyLit iName)|]
-promoteArgumentSpecType (NewIdArgument iName) = [t|NewId $(litT $ strTyLit iName)|]
-promoteArgumentSpecType arg = do
-  argExp <- (TH.lift arg)
-  matchCon argExp
-  where
-    matchCon :: Exp -> Q Type
-    matchCon (ConE name) = pure $ ConT name
-    matchCon (AppE x _) = matchCon x
-    matchCon _ = fail "Can only promote ConE expression"
+argumentType argSpec = promoteArgumentType argSpec.argType
+
+promoteArgumentType :: ArgumentType -> Q Type
+promoteArgumentType IntArgument = [t|Int32|]
+promoteArgumentType UIntArgument = [t|Word32|]
+promoteArgumentType FixedArgument = [t|Fixed|]
+promoteArgumentType StringArgument = [t|WlString|]
+promoteArgumentType ArrayArgument = [t|BS.ByteString|]
+promoteArgumentType (ObjectArgument iName) = [t|ObjectId $(litT (strTyLit iName))|]
+promoteArgumentType GenericObjectArgument = [t|GenericObjectId|]
+promoteArgumentType (NewIdArgument iName) = [t|NewId $(litT (strTyLit iName))|]
+promoteArgumentType GenericNewIdArgument = [t|GenericNewId|]
+promoteArgumentType FdArgument = [t|Void|] -- TODO
 
 defaultBangType :: Q Type -> Q BangType
 defaultBangType = bangType (bang noSourceUnpackedness noSourceStrictness)
 
+varDefaultBangType  :: Name -> Q Type -> Q VarBangType
+varDefaultBangType name qType = varBangType name $ bangType (bang noSourceUnpackedness noSourceStrictness) qType
+
 
 -- | (a -> b -> c -> d) -> [m a, m b, m c] -> m d
 applyA :: Q Exp -> [Q Exp] -> Q Exp
-- 
GitLab