diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs index 9e0b4101712fda43d36a882b4f14b22d3327352b..27a9159b42a81261c4286b89d0b09100752ed9c6 100644 --- a/src/Quasar/Wayland/Client.hs +++ b/src/Quasar/Wayland/Client.hs @@ -31,11 +31,11 @@ instance IsDisposable WaylandClient where toDisposable (WaylandClient connection) = toDisposable connection newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient -newWaylandClient socket = WaylandClient <$> newWaylandConnection wlDisplayCallback socket +newWaylandClient socket = WaylandClient <$> newWaylandConnection clientCallback clientCallback socket -wlDisplayCallback :: ClientCallback STM I_wl_display -wlDisplayCallback = Callback { - messageCallback = \_ _ -> lift $ traceM "Callback called" +clientCallback :: IsInterface i => ClientCallback STM i +clientCallback = Callback { + messageCallback = \x y -> lift $ traceM $ objectInterfaceName x <> "#" <> show (objectId x) <> "." <> messageName y } connectWaylandClient :: MonadResourceManager m => m WaylandClient diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs index ebb020b45831275b2cc622b776544a080d327d78..359022359719dcdd9fe80e38ceb129aeec5e64b3 100644 --- a/src/Quasar/Wayland/Connection.hs +++ b/src/Quasar/Wayland/Connection.hs @@ -34,9 +34,9 @@ data SocketClosed = SocketClosed deriving stock Show deriving anyclass Exception -newWaylandConnection :: forall s m. MonadResourceManager m => Callback s STM I_wl_display -> Socket -> m (WaylandConnection s) -newWaylandConnection wlDisplayCallback socket = do - protocolStateVar <- liftIO $ newTVarIO $ initialProtocolState wlDisplayCallback +newWaylandConnection :: forall s m. MonadResourceManager m => Callback s STM I_wl_display -> Callback s STM I_wl_registry -> Socket -> m (WaylandConnection s) +newWaylandConnection wlDisplayCallback wlRegistryCallback socket = do + protocolStateVar <- liftIO $ newTVarIO $ initialProtocolState wlDisplayCallback wlRegistryCallback outboxVar <- liftIO newEmptyTMVarIO resourceManager <- newResourceManager diff --git a/src/Quasar/Wayland/Core.hs b/src/Quasar/Wayland/Core.hs index 7dc6c5dfa50d0e3137d3dcb6c5fa2b13253b2284..e144c2e2a7c5e9abe12f01544a23059dacb96b15 100644 --- a/src/Quasar/Wayland/Core.hs +++ b/src/Quasar/Wayland/Core.hs @@ -4,6 +4,7 @@ module Quasar.Wayland.Core ( IsInterface(..), Side(..), Object, + IsSomeObject(..), IsSomeObject, IsMessage(..), ProtocolState, @@ -21,6 +22,9 @@ module Quasar.Wayland.Core ( import Control.Monad (replicateM_) import Control.Monad.Catch +import Control.Monad.Catch.Pure +import Control.Monad.Reader (ReaderT, runReaderT) +import Control.Monad.Writer (WriterT, runWriterT) import Control.Monad.State (StateT, runStateT, lift) import Control.Monad.State qualified as State import Data.Binary @@ -48,44 +52,42 @@ newtype Fixed = Fixed Word32 class WireFormat a where type Argument a putArgument :: Argument a -> StateT (ProtocolState s m) PutM () - getArgument :: StateT (ProtocolState s m) Get (Argument a) + getArgument :: WireGet s m (Argument a) instance WireFormat "int" where type Argument "int" = Int32 putArgument = lift . putInt32host - getArgument = lift getInt32host + getArgument = liftGet getInt32host instance WireFormat "uint" where type Argument "uint" = Word32 putArgument = lift . putWord32host - getArgument = lift getWord32host + getArgument = liftGet getWord32host instance WireFormat "fixed" where type Argument "fixed" = Fixed putArgument (Fixed repr) = lift $ putWord32host repr - getArgument = lift $ Fixed <$> getWord32host + getArgument = liftGet $ Fixed <$> getWord32host instance WireFormat "string" where type Argument "string" = BS.ByteString putArgument = lift . putWaylandBlob - getArgument = lift getWaylandBlob + getArgument = liftGet getWaylandBlob -data WireObject s m i - -instance forall (s :: Side) m i. MonadCatch m => WireFormat (WireObject s m i) where - type Argument (WireObject s m i) = Object s m i +instance forall (s :: Side) m i. MonadCatch m => WireFormat (Object s m i) where + type Argument (Object s m i) = Object s m i putArgument = undefined getArgument = undefined -instance WireFormat "new_id" where - type Argument "new_id" = Void +instance WireFormat (NewId s m i) where + type Argument (NewId s m i) = (NewId s m i) putArgument = undefined getArgument = undefined instance WireFormat "array" where type Argument "array" = BS.ByteString putArgument = lift . putWaylandBlob - getArgument = lift getWaylandBlob + getArgument = liftGet getWaylandBlob instance WireFormat "fd" where type Argument "fd" = Void @@ -93,51 +95,70 @@ instance WireFormat "fd" where getArgument = undefined - -- | A wayland interface -class (Binary (TRequest i), Binary (TEvent i)) => IsInterface i where - type TRequest i - type TEvent i +class (IsMessage (Request i), Binary (Request i), IsMessage (Event i), Binary (Event i)) => IsInterface i where + type Request i + type Event i interfaceName :: String -class IsInterface i => IsObject (s :: Side) i where - type Up s i - type Down s i +type family Up (s :: Side) i where + Up 'Client i = Request i + Up 'Server i = Event i + +type family Down (s :: Side) i where + Down 'Client i = Event i + Down 'Server i = Request i +-- | Data kind data Side = Client | Server data Object s m i = IsInterface i => Object ObjectId (Callback s m i) -instance IsInterface i => IsObject 'Client i where - type Up 'Client i = TRequest i - type Down 'Client i = TEvent i - -instance IsInterface i => IsObject 'Server i where - type Up 'Server i = TEvent i - type Down 'Server i = TRequest i instance forall s m i. IsInterface i => IsSomeObject (Object s m i) where objectId (Object oId _) = oId objectInterfaceName _ = interfaceName @i -class IsSomeObject i where - objectId :: i -> ObjectId - objectInterfaceName :: i -> String +class IsSomeObject a where + objectId :: a -> ObjectId + objectInterfaceName :: a -> String -- | Wayland object quantification wrapper -data SomeObject = forall i. IsSomeObject i => SomeObject i +data SomeObject s m = forall i. IsInterface i => SomeObject (Object s m i) -instance IsSomeObject SomeObject where +instance IsSomeObject (SomeObject s m) where objectId (SomeObject object) = objectId object objectInterfaceName (SomeObject object) = objectInterfaceName object -class IsMessage i where - messageName :: i -> String +data NewId s m i = IsInterface i => NewId ObjectId + + +class IsMessage a where + messageName :: a -> String + getMessage :: IsInterface i => Object s m i -> Opcode -> WireGet s m a + putMessage :: a -> StateT (ProtocolState s m) PutM () instance IsMessage Void where messageName = absurd + getMessage = invalidOpcode + putMessage = absurd + +describeMessage + :: forall s m i. IsInterface i + => Object s m i + -> Opcode + -> BSL.ByteString + -> String +describeMessage object opcode body = + objectInterfaceName object <> "@" <> show (objectId object) <> + ".msg#" <> show opcode <> + " (" <> show (BSL.length body) <> "B)" + +invalidOpcode :: IsInterface i => Object s m i -> Opcode -> WireGet s m a +invalidOpcode object opcode = + throwM $ ProtocolException $ "Invalid opcode " <> show opcode <> " on " <> objectInterfaceName object <> "@" <> show (objectId object) -- TODO remove @@ -174,9 +195,9 @@ data ProtocolState (s :: Side) m = ProtocolState { protocolException :: Maybe SomeException, bytesReceived :: Word64, bytesSent :: Word64, - inboxDecoder :: Decoder (ObjectId, Opcode, BSL.ByteString), + inboxDecoder :: Decoder RawMessage, outbox :: Maybe Put, - objects :: HashMap ObjectId SomeObject + objects :: HashMap ObjectId (SomeObject s m) } @@ -193,7 +214,7 @@ data CallbackFailed = CallbackFailed SomeException deriving stock Show deriving anyclass Exception -data ParserFailed = ParserFailed String +data ParserFailed = ParserFailed String String deriving stock Show deriving anyclass Exception @@ -220,16 +241,24 @@ protocolStep action inState = do then st else st{protocolException = Just (toException ex)} +type WireGet s m a = ReaderT (HashMap ObjectId (SomeObject s m)) (WriterT [StateT (ProtocolState s m) m ()] (CatchT Get)) a + +liftGet :: Get a -> WireGet s m a +liftGet = lift . lift . lift + -- * Exported functions initialProtocolState - :: forall wl_display s m. IsInterface wl_display + :: forall wl_display wl_registry s m. (IsInterface wl_display, IsInterface wl_registry) => Callback s m wl_display + -> Callback s m wl_registry -> ProtocolState s m -initialProtocolState wlDisplayCallback = sendInitialMessage initialState +initialProtocolState wlDisplayCallback wlRegistryCallback = sendInitialMessage initialState where wlDisplay :: Object s m wl_display wlDisplay = Object 1 wlDisplayCallback + wlRegistry :: Object s m wl_registry + wlRegistry = Object 2 wlRegistryCallback initialState :: ProtocolState s m initialState = ProtocolState { protocolException = Nothing, @@ -237,7 +266,7 @@ initialProtocolState wlDisplayCallback = sendInitialMessage initialState bytesSent = 0, inboxDecoder = runGetIncremental getRawMessage, outbox = Nothing, - objects = HM.singleton 1 (SomeObject wlDisplay) + objects = HM.fromList [(1, (SomeObject wlDisplay)), (2, (SomeObject wlRegistry))] } -- | Feed the protocol newly received data @@ -275,32 +304,54 @@ sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2] runCallbacks :: MonadCatch m => StateT (ProtocolState s m) m () runCallbacks = receiveRawMessage >>= \case Nothing -> pure () - Just message -> do - handleMessage message + Just rawMessage -> do + handleMessage rawMessage runCallbacks -handleMessage :: MonadCatch m => RawMessage -> StateT (ProtocolState s m) m () -handleMessage (oId, opcode, body) = do +handleMessage :: forall s m. MonadCatch m => RawMessage -> StateT (ProtocolState s m) m () +handleMessage rawMessage@(oId, opcode, body) = do st <- State.get case HM.lookup oId st.objects of Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId - Just object -> traceM (objectInterfaceName object) + Just (SomeObject object) -> do + case runGetOrFail (getMessageAction st.objects object rawMessage) body of + Left (_, _, message) -> + throwM $ ParserFailed (describeMessage object opcode body) message + Right ("", _, result) -> + traceM $ "Received message " <> (describeMessage object opcode body) + Right (leftovers, _, _) -> + throwM $ ParserFailed (describeMessage object opcode body) (show (BSL.length leftovers) <> "B not parsed") + +getMessageAction + :: MonadCatch m + => HashMap ObjectId (SomeObject s m) + -> Object s m i + -> RawMessage + -> Get (ProtocolAction s m ()) +getMessageAction objects object@(Object _ callback) (oId, opcode, body) = do + pure $ traceM $ "Received message " <> objectInterfaceName object <> "@" <> show oId <> ".msg#" <> show opcode <> " (" <> show (BSL.length body) <> "B)" + + +type ProtocolAction s m a = StateT (ProtocolState s m) m a type RawMessage = (ObjectId, Opcode, BSL.ByteString) -receiveRawMessage :: MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage) +receiveRawMessage :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage) receiveRawMessage = do st <- State.get (result, newDecoder) <- checkDecoder st.inboxDecoder State.put st{inboxDecoder = newDecoder} - pure result where - checkDecoder :: MonadCatch m => Decoder RawMessage -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage) - checkDecoder (Fail _ _ message) = throwM (ParserFailed message) + checkDecoder + :: MonadCatch m + => Decoder RawMessage + -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage) + checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message) checkDecoder x@(Partial _) = pure (Nothing, x) checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers) + getRawMessage :: Get RawMessage getRawMessage = do oId <- getWord32host