From 5ac56da09288dfa40c830d4841b1620d624b7885 Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Wed, 15 Sep 2021 22:37:45 +0200 Subject: [PATCH] Add api to create objects; make ids typesafe; remove initialization hack --- src/Quasar/Wayland/Client.hs | 47 ++++++----- src/Quasar/Wayland/Connection.hs | 20 ++--- src/Quasar/Wayland/Protocol/Core.hs | 120 ++++++++++++++++++++-------- src/Quasar/Wayland/Protocol/TH.hs | 46 +++++++---- 4 files changed, 147 insertions(+), 86 deletions(-) diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs index 65e9d5f..658a2ac 100644 --- a/src/Quasar/Wayland/Client.hs +++ b/src/Quasar/Wayland/Client.hs @@ -1,6 +1,7 @@ module Quasar.Wayland.Client ( connectWaylandClient, newWaylandClient, + connectWaylandSocket, ) where import Control.Concurrent.STM @@ -22,42 +23,40 @@ import System.FilePath ((</>), isRelative) import Text.Read (readEither) -data WaylandClient = WaylandClient (WaylandConnection 'Client) +data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client STM I_wl_display) instance IsResourceManager WaylandClient where - toResourceManager (WaylandClient connection) = toResourceManager connection + toResourceManager (WaylandClient connection _) = toResourceManager connection instance IsDisposable WaylandClient where - toDisposable (WaylandClient connection) = toDisposable connection + toDisposable (WaylandClient connection _) = toDisposable connection newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient -newWaylandClient socket = WaylandClient <$> - newWaylandConnection - @I_wl_display - @I_wl_registry - (traceCallback ignoreMessage) - -- HACK to send get_registry - (Just (R_wl_display_get_registry (NewId 2))) - (traceCallback ignoreMessage) - socket +newWaylandClient socket = do + (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket + + (wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @STM @I_wl_registry (traceCallback ignoreMessage) + stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId + pure $ WaylandClient connection wlDisplay connectWaylandClient :: MonadResourceManager m => m WaylandClient connectWaylandClient = mask_ do socket <- liftIO connectWaylandSocket newWaylandClient socket - where - connectWaylandSocket :: IO Socket - connectWaylandSocket = do - lookupEnv "WAYLAND_SOCKET" >>= \case - -- Parent process already established connection - Just waylandSocketEnv -> do - case readEither waylandSocketEnv of - Left err -> fail $ "Failed to parse WAYLAND_SOCKET: " <> err - Right fd -> Socket.mkSocket fd - Nothing -> do - path <- getWaylandSocketPath - newUnixSocket path +connectWaylandSocket :: IO Socket +connectWaylandSocket = do + lookupEnv "WAYLAND_SOCKET" >>= \case + -- Parent process already established connection + Just waylandSocketEnv -> do + case readEither waylandSocketEnv of + Left err -> fail $ "Failed to parse WAYLAND_SOCKET: " <> err + Right fd -> Socket.mkSocket fd + Nothing -> do + path <- getWaylandSocketPath + newUnixSocket path + + where getWaylandSocketPath :: IO FilePath getWaylandSocketPath = do waylandDisplayEnv <- lookupEnv "WAYLAND_DISPLAY" diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs index 8efeed6..f5b0b99 100644 --- a/src/Quasar/Wayland/Connection.hs +++ b/src/Quasar/Wayland/Connection.hs @@ -1,6 +1,7 @@ module Quasar.Wayland.Connection ( WaylandConnection, newWaylandConnection, + stepProtocol, ) where import Control.Concurrent.STM @@ -35,13 +36,11 @@ data SocketClosed = SocketClosed deriving anyclass Exception newWaylandConnection - :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry, MonadResourceManager m) + :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m) => Callback s STM wl_display - -> Maybe (Up s wl_display) - -> Callback s STM wl_registry -> Socket - -> m (WaylandConnection s) -newWaylandConnection wlDisplayCallback initializationMessage wlRegistryCallback socket = do + -> m (WaylandConnection s, Object s STM wl_display) +newWaylandConnection wlDisplayCallback socket = do protocolStateVar <- liftIO $ newTVarIO protocolState outboxVar <- liftIO newEmptyTMVarIO @@ -61,16 +60,9 @@ newWaylandConnection wlDisplayCallback initializationMessage wlRegistryCallback connectionThread connection $ sendThread connection connectionThread connection $ receiveThread connection - -- Create registry, if requested - forM_ initializationMessage \msg -> - sendProtocolMessage connection wlDisplay msg - - pure connection + pure (connection, wlDisplay) where - (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback wlRegistryCallback - -sendProtocolMessage :: (IsInterfaceSide s i, MonadIO m) => WaylandConnection s -> Object s STM i -> Up s i -> m () -sendProtocolMessage connection object message = stepProtocol connection $ sendMessage object message + (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s STM a -> m a stepProtocol connection step = liftIO do diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index 6616854..3ab1775 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -2,6 +2,7 @@ module Quasar.Wayland.Protocol.Core ( ObjectId, + GenericObjectId, NewId(..), Opcode, ArgumentType(..), @@ -22,6 +23,7 @@ module Quasar.Wayland.Protocol.Core ( ProtocolStep, initialProtocolState, sendMessage, + newObject, feedInput, setException, @@ -47,12 +49,18 @@ import Data.ByteString.Lazy qualified as BSL import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as HM import Data.Maybe (isJust) +import Data.Proxy import Data.Void (absurd) +import GHC.TypeLits import Language.Haskell.TH.Syntax (Lift) import Quasar.Prelude -type ObjectId = Word32 +newtype ObjectId (j :: Symbol) = ObjectId GenericObjectId + deriving stock (Eq, Show) + +type GenericObjectId = Word32 + type Opcode = Word16 -- | Signed 24.8 decimal numbers. @@ -62,15 +70,17 @@ newtype Fixed = Fixed Word32 instance Show Fixed where show x = "[fixed " <> show x <> "]" -newtype NewId = NewId ObjectId - deriving newtype (Eq, Show) +newtype NewId (j :: Symbol) = NewId GenericObjectId + deriving stock (Eq, Show) + +newtype GenericNewId = GenericNewId GenericObjectId + deriving stock (Eq, Show) dropRemaining :: Get () dropRemaining = void getRemainingLazyByteString - data ArgumentType = IntArgument | UIntArgument @@ -78,9 +88,9 @@ data ArgumentType | StringArgument | ArrayArgument | ObjectArgument String - | UnknownObjectArgument + | GenericObjectArgument | NewIdArgument String - | UnknownNewIdArgument + | GenericNewIdArgument | FdArgument deriving stock (Eq, Show, Lift) @@ -120,29 +130,29 @@ instance WireFormat 'ArrayArgument where getArgument = getWaylandBlob showArgument array = "[array " <> show (BS.length array) <> "B]" -instance WireFormat 'ObjectArgument where - type Argument 'ObjectArgument = ObjectId - putArgument = putWord32host - getArgument = getWord32host - showArgument oId = "@" <> show oId +instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where + type Argument (ObjectId j) = ObjectId j + putArgument (ObjectId oId) = putWord32host oId + getArgument = ObjectId <$> getWord32host + showArgument (ObjectId oId) = symbolVal @j Proxy <> "@" <> show oId -instance WireFormat 'UnknownObjectArgument where - type Argument 'UnknownObjectArgument = ObjectId +instance WireFormat 'GenericObjectArgument where + type Argument 'GenericObjectArgument = GenericObjectId putArgument = putWord32host getArgument = getWord32host - showArgument oId = "@" <> show oId + showArgument oId = "[unknown]@" <> show oId -instance WireFormat 'NewIdArgument where - type Argument 'NewIdArgument = NewId +instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where + type Argument (NewId j) = NewId j putArgument (NewId newId) = putWord32host newId getArgument = NewId <$> getWord32host - showArgument newId = "new @" <> show newId + showArgument (NewId newId) = "new " <> symbolVal @j Proxy <> "@" <> show newId -instance WireFormat 'UnknownNewIdArgument where - type Argument 'UnknownNewIdArgument = NewId - putArgument (NewId newId) = putWord32host newId - getArgument = NewId <$> getWord32host - showArgument newId = "new @" <> show newId +instance WireFormat 'GenericNewIdArgument where + type Argument 'GenericNewIdArgument = GenericNewId + putArgument (GenericNewId newId) = putWord32host newId + getArgument = GenericNewId <$> getWord32host + showArgument newId = "new [unknown]@" <> show newId instance WireFormat 'FdArgument where type Argument 'FdArgument = Void @@ -159,19 +169,27 @@ class ( => IsInterface i where type Request i type Event i + type InterfaceName i :: Symbol interfaceName :: String class IsSide (s :: Side) where type Up s i type Down s i + initialId :: GenericObjectId + maximumId :: GenericObjectId instance IsSide 'Client where type Up 'Client i = Request i type Down 'Client i = Event i + -- Id #1 is reserved for wl_display + initialId = 2 + maximumId = 0xfeffffff instance IsSide 'Server where type Up 'Server i = Event i type Down 'Server i = Request i + initialId = 0xff000000 + maximumId = 0xffffffff --- | Empty class, used to combine constraints @@ -198,10 +216,10 @@ class IsInterfaceSide s i => IsInterfaceHandler s m i a where -- | Data kind data Side = Client | Server -data Object s m i = IsInterfaceSide s i => Object ObjectId (Callback s m i) +data Object s m i = IsInterfaceSide s i => Object GenericObjectId (Callback s m i) class IsObject a where - objectId :: a -> ObjectId + objectId :: a -> GenericObjectId objectInterfaceName :: a -> String class IsObjectSide a where @@ -225,7 +243,7 @@ instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where -- | Wayland object quantification wrapper data SomeObject s m = forall i. IsInterfaceSide s i => SomeObject (Object s m i) - | UnknownObject String ObjectId + | UnknownObject String GenericObjectId instance IsObject (SomeObject s m) where objectId (SomeObject object) = objectId object @@ -269,7 +287,8 @@ data ProtocolState (s :: Side) m = ProtocolState { bytesSent :: !Int64, inboxDecoder :: Decoder RawMessage, outbox :: Maybe Put, - objects :: HashMap ObjectId (SomeObject s m) + objects :: HashMap GenericObjectId (SomeObject s m), + nextId :: Word32 } @@ -320,10 +339,15 @@ data ProtocolException = ProtocolException String deriving stock Show deriving anyclass Exception +data MaximumIdReached = MaximumIdReached + deriving stock Show + deriving anyclass Exception + -- * Monad plumbing type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m) +-- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions. type ProtocolAction s m a = StateT (ProtocolState s m) m a protocolStep :: forall s m a. MonadCatch m => ProtocolAction s m a -> ProtocolStep s m a @@ -345,16 +369,13 @@ protocolStep action inState = do -- * Exported functions initialProtocolState - :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry) + :: forall wl_display wl_registry s m. IsInterfaceSide s wl_display => Callback s m wl_display - -> Callback s m wl_registry -> (ProtocolState s m, Object s m wl_display) -initialProtocolState wlDisplayCallback wlRegistryCallback = (initialState, wlDisplay) +initialProtocolState wlDisplayCallback = (initialState, wlDisplay) where wlDisplay :: Object s m wl_display wlDisplay = Object 1 wlDisplayCallback - wlRegistry :: Object s m wl_registry - wlRegistry = Object 2 wlRegistryCallback initialState :: ProtocolState s m initialState = ProtocolState { protocolException = Nothing, @@ -362,7 +383,8 @@ initialProtocolState wlDisplayCallback wlRegistryCallback = (initialState, wlDis bytesSent = 0, inboxDecoder = runGetIncremental getRawMessage, outbox = Nothing, - objects = HM.fromList [(1, (SomeObject wlDisplay)), (2, (SomeObject wlRegistry))] + objects = HM.fromList [(1, (SomeObject wlDisplay))], + nextId = initialId @s } -- | Feed the protocol newly received data @@ -380,6 +402,38 @@ setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m () setException ex = protocolStep do State.modify \st -> st{protocolException = Just (toException ex)} + +-- Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object. +newObject + :: forall s m i. (IsInterfaceSide s i, MonadCatch m) + => Callback s m i + -> ProtocolStep s m (Object s m i, NewId (InterfaceName i)) +newObject callback = protocolStep $ newObjectInternal callback + +newObjectInternal + :: forall s m i. (IsInterfaceSide s i, MonadCatch m) + => Callback s m i + -> ProtocolAction s m (Object s m i, NewId (InterfaceName i)) +newObjectInternal callback = do + oId <- allocateObjectId @s @m @i + let + object = Object oId callback + someObject = SomeObject object + State.modify \st -> st { objects = HM.insert oId someObject st.objects} + pure (object, NewId oId) + where + allocateObjectId :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => ProtocolAction s m GenericObjectId + allocateObjectId = do + st <- State.get + let + id = st.nextId + nextId' = id + 1 + + when (nextId' == maximumId @s) $ throwM MaximumIdReached + State.put $ st {nextId = nextId'} + pure id + + -- | Sends a message without checking any ids or creating proxy objects objects. sendMessage :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m () sendMessage object message = protocolStep $ sendMessageInternal object message @@ -445,7 +499,7 @@ getMessageAction object@(Object _ objectHandler) opcode = do message <- getDown object opcode pure $ handleMessage objectHandler object message -type RawMessage = (ObjectId, Opcode, BSL.ByteString) +type RawMessage = (GenericObjectId, Opcode, BSL.ByteString) receiveRawMessage :: forall s m. MonadCatch m => ProtocolAction s m (Maybe RawMessage) receiveRawMessage = do diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs index 8be2a10..cb94e96 100644 --- a/src/Quasar/Wayland/Protocol/TH.hs +++ b/src/Quasar/Wayland/Protocol/TH.hs @@ -89,9 +89,10 @@ interfaceDecs interface = do iName = interfaceN interface iT = interfaceT interface instanceDecs = [ - valD (varP 'interfaceName) (normalB (stringE interface.name)) [], tySynInstD (tySynEqn Nothing (appT (conT ''Request) iT) rT), - tySynInstD (tySynEqn Nothing (appT (conT ''Event) iT) eT) + tySynInstD (tySynEqn Nothing (appT (conT ''Event) iT) eT), + tySynInstD (tySynEqn Nothing (appT (conT ''InterfaceName) iT) (litT (strTyLit interface.name))), + valD (varP 'interfaceName) (normalB (stringE interface.name)) [] ] rT :: Q Type rT = if length interface.requests > 0 then conT rTypeName else [t|Void|] @@ -149,6 +150,15 @@ interfaceDecs interface = do applyArgTypes :: Q Type -> Q Type applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments) +interfaceSideInstanceDs :: InterfaceSpec -> Q [Dec] +interfaceSideInstanceDs interface = execWriterT do + tellQs [d|instance IsInterfaceSide 'Client $iT|] + tellQs [d|instance IsInterfaceSide 'Server $iT|] + --tellQs [d|instance forall m a. IsInterfaceHandler 'Client m $iT a where {handleMessage = undefined}|] + --tellQs [d|instance forall m a. IsInterfaceHandler 'Server m $iT a where {handleMessage = undefined}|] + where + iT = interfaceT interface + interfaceN :: InterfaceSpec -> Name interfaceN interface = mkName $ "I_" <> interface.name @@ -272,12 +282,14 @@ argumentSpecType :: ArgumentSpec -> Q Type argumentSpecType argSpec = promoteArgumentSpecType argSpec.argType promoteArgumentSpecType :: ArgumentType -> Q Type +promoteArgumentSpecType (ObjectArgument iName) = [t|ObjectId $(litT $ strTyLit iName)|] +promoteArgumentSpecType (NewIdArgument iName) = [t|NewId $(litT $ strTyLit iName)|] promoteArgumentSpecType arg = do argExp <- (TH.lift arg) - ConT <$> matchCon argExp + matchCon argExp where - matchCon :: Exp -> Q Name - matchCon (ConE name) = pure name + matchCon :: Exp -> Q Type + matchCon (ConE name) = pure $ ConT name matchCon (AppE x _) = matchCon x matchCon _ = fail "Can only promote ConE expression" @@ -310,8 +322,8 @@ parseInterface :: MonadFail m => Element -> m InterfaceSpec parseInterface element = do name <- getAttr "name" element version <- read <$> getAttr "version" element - requests <- mapM parseRequest $ zip [0..] $ findChildren (qname "request") element - events <- mapM parseEvent $ zip [0..] $ findChildren (qname "event") element + requests <- mapM (parseRequest name) $ zip [0..] $ findChildren (qname "request") element + events <- mapM (parseEvent name) $ zip [0..] $ findChildren (qname "event") element pure InterfaceSpec { name, version, @@ -319,17 +331,21 @@ parseInterface element = do events } -parseRequest :: MonadFail m => (Opcode, Element) -> m RequestSpec -parseRequest x = RequestSpec <$> parseMessage x +parseRequest :: MonadFail m => String -> (Opcode, Element) -> m RequestSpec +parseRequest x y = RequestSpec <$> parseMessage x y -parseEvent :: MonadFail m => (Opcode, Element) -> m EventSpec -parseEvent x = EventSpec <$> parseMessage x +parseEvent :: MonadFail m => String -> (Opcode, Element) -> m EventSpec +parseEvent x y = EventSpec <$> parseMessage x y -parseMessage :: MonadFail m => (Opcode, Element) -> m MessageSpec -parseMessage (opcode, element) = do +parseMessage :: MonadFail m => String -> (Opcode, Element) -> m MessageSpec +parseMessage interfaceName (opcode, element) = do name <- getAttr "name" element since <- read <<$>> peekAttr "since" element arguments <- mapM parseArgument $ zip [0..] $ findChildren (qname "arg") element + forM_ arguments \arg -> + when + do arg.argType == GenericNewIdArgument && interfaceName /= "wl_registry" + do fail $ "Invalid GenericNewIdArgument encountered on " <> interfaceName <> "." <> name <> " (only valid on wl_registry)" pure MessageSpec { name, since, @@ -357,9 +373,9 @@ parseArgument (index, element) = do parseArgumentType "string" Nothing = pure StringArgument parseArgumentType "array" Nothing = pure ArrayArgument parseArgumentType "object" (Just interface) = pure (ObjectArgument interface) - parseArgumentType "object" Nothing = pure UnknownObjectArgument + parseArgumentType "object" Nothing = pure GenericObjectArgument parseArgumentType "new_id" (Just interface) = pure (NewIdArgument interface) - parseArgumentType "new_id" Nothing = pure UnknownNewIdArgument + parseArgumentType "new_id" Nothing = pure GenericNewIdArgument parseArgumentType "fd" Nothing = pure FdArgument parseArgumentType x Nothing = fail $ "Unknown argument type \"" <> x <> "\" encountered" parseArgumentType x _ = fail $ "Argument type \"" <> x <> "\" should not have \"interface\" attribute" -- GitLab