From b43d66a94c48d9d1091467ca6dc416c7f1698cca Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Thu, 16 Sep 2021 14:39:36 +0200 Subject: [PATCH] Pin the protocol monad to STM I really like the idea of a pure wire protocol implementation and might revisit that later, but to get an MVP it's much easier to use STM in the core, since all interface impementations are going to use it anyway. --- src/Quasar/Wayland/Client.hs | 4 +- src/Quasar/Wayland/Connection.hs | 8 +- src/Quasar/Wayland/Protocol.hs | 11 +-- src/Quasar/Wayland/Protocol/Core.hs | 111 ++++++++++++++-------------- 4 files changed, 68 insertions(+), 66 deletions(-) diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs index 6f8f2cf..2465482 100644 --- a/src/Quasar/Wayland/Client.hs +++ b/src/Quasar/Wayland/Client.hs @@ -18,7 +18,7 @@ import System.FilePath ((</>), isRelative) import Text.Read (readEither) -data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client STM I_wl_display) +data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client I_wl_display) instance IsResourceManager WaylandClient where toResourceManager (WaylandClient connection _) = toResourceManager connection @@ -30,7 +30,7 @@ newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient newWaylandClient socket = do (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket - (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @STM @I_wl_registry (traceCallback ignoreMessage) + (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage) stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId pure $ WaylandClient connection wlDisplay diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs index f5b0b99..4a401e3 100644 --- a/src/Quasar/Wayland/Connection.hs +++ b/src/Quasar/Wayland/Connection.hs @@ -19,7 +19,7 @@ import Quasar.Wayland.Protocol.Generated data WaylandConnection s = WaylandConnection { - protocolStateVar :: TVar (ProtocolState s STM), + protocolStateVar :: TVar (ProtocolState s), outboxVar :: TMVar BSL.ByteString, socket :: Socket, resourceManager :: ResourceManager @@ -37,9 +37,9 @@ data SocketClosed = SocketClosed newWaylandConnection :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m) - => Callback s STM wl_display + => Callback s wl_display -> Socket - -> m (WaylandConnection s, Object s STM wl_display) + -> m (WaylandConnection s, Object s wl_display) newWaylandConnection wlDisplayCallback socket = do protocolStateVar <- liftIO $ newTVarIO protocolState outboxVar <- liftIO newEmptyTMVarIO @@ -64,7 +64,7 @@ newWaylandConnection wlDisplayCallback socket = do where (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback -stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s STM a -> m a +stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s a -> m a stepProtocol connection step = liftIO do result <- atomically do oldState <- readTVar connection.protocolStateVar diff --git a/src/Quasar/Wayland/Protocol.hs b/src/Quasar/Wayland/Protocol.hs index 03dd836..7e8d34b 100644 --- a/src/Quasar/Wayland/Protocol.hs +++ b/src/Quasar/Wayland/Protocol.hs @@ -3,6 +3,7 @@ module Quasar.Wayland.Protocol ( createClientStateWithRegistry ) where +import Control.Concurrent.STM import Control.Monad.Catch import Control.Monad.State (StateT, runStateT) import Data.ByteString.UTF8 (toString) @@ -11,24 +12,24 @@ import Quasar.Wayland.Protocol.Core import Quasar.Wayland.Protocol.Generated -createClientStateWithRegistry :: forall m. MonadCatch m => m (ProtocolState 'Client m) +createClientStateWithRegistry :: STM (ProtocolState 'Client) createClientStateWithRegistry = do (wlRegistry, state') <- runStateT go initialState' pure state' where (initialState', wlDisplay) = initialProtocolState wlDisplayCallback - go :: ProtocolAction 'Client m (Object 'Client m I_wl_registry) + go :: ProtocolAction 'Client (Object 'Client I_wl_registry) go = do - (wlRegistry, newId) <- newObjectInternal @'Client @m @I_wl_registry (traceCallback ignoreMessage) + (wlRegistry, newId) <- newObjectInternal @'Client @I_wl_registry (traceCallback ignoreMessage) sendMessageInternal wlDisplay $ R_wl_display_get_registry newId pure wlRegistry - wlDisplayCallback :: forall m. (IsInterfaceSide 'Client I_wl_display, MonadCatch m) => Callback 'Client m I_wl_display + wlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Client I_wl_display wlDisplayCallback = internalFnCallback handler where - handler :: Object 'Client m I_wl_display -> E_wl_display -> ProtocolAction 'Client m () + handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolAction 'Client () -- TODO parse oId handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (toString message) handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index a8e355a..73067ae 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -41,6 +41,7 @@ module Quasar.Wayland.Protocol.Core ( invalidOpcode, ) where +import Control.Concurrent.STM import Control.Monad (replicateM_) import Control.Monad.Catch import Control.Monad.State (StateT, runStateT) @@ -213,21 +214,21 @@ class ( => IsInterfaceSide (s :: Side) i -getDown :: forall s m i. IsInterfaceSide s i => Object s m i -> Opcode -> Get (Down s i) +getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (Down s i) getDown = getMessage @(Down s i) -putUp :: forall s m i. IsInterfaceSide s i => Object s m i -> Up s i -> PutM Opcode +putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> PutM Opcode putUp _ = putMessage @(Up s i) -class IsInterfaceSide s i => IsInterfaceHandler s m i a where - handleMessage :: a -> Object s m i -> Down s i -> ProtocolAction s m () +class IsInterfaceSide s i => IsInterfaceHandler s i a where + handleMessage :: a -> Object s i -> Down s i -> ProtocolAction s () -- | Data kind data Side = Client | Server -data Object s m i = IsInterfaceSide s i => Object GenericObjectId (Callback s m i) +data Object s i = IsInterfaceSide s i => Object GenericObjectId (Callback s i) class IsObject a where objectId :: a -> GenericObjectId @@ -237,11 +238,11 @@ class IsObjectSide a where describeUpMessage :: a -> Opcode -> BSL.ByteString -> String describeDownMessage :: a -> Opcode -> BSL.ByteString -> String -instance forall s m i. IsInterface i => IsObject (Object s m i) where +instance forall s i. IsInterface i => IsObject (Object s i) where objectId (Object oId _) = oId objectInterfaceName _ = interfaceName @i -instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where +instance forall s i. IsInterfaceSide s i => IsObjectSide (Object s i) where describeUpMessage object opcode body = objectInterfaceName object <> "@" <> show (objectId object) <> "." <> fromMaybe "[invalidOpcode]" (opcodeName @(Up s i) opcode) <> @@ -252,17 +253,17 @@ instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where " (" <> show (BSL.length body) <> "B)" -- | Wayland object quantification wrapper -data SomeObject s m - = forall i. IsInterfaceSide s i => SomeObject (Object s m i) +data SomeObject s + = forall i. IsInterfaceSide s i => SomeObject (Object s i) | UnknownObject String GenericObjectId -instance IsObject (SomeObject s m) where +instance IsObject (SomeObject s) where objectId (SomeObject object) = objectId object objectId (UnknownObject _ oId) = oId objectInterfaceName (SomeObject object) = objectInterfaceName object objectInterfaceName (UnknownObject interface _) = interface -instance IsObjectSide (SomeObject s m) where +instance IsObjectSide (SomeObject s) where describeUpMessage (SomeObject object) = describeUpMessage object describeUpMessage (UnknownObject interface oId) = \opcode body -> interface <> "@" <> show oId <> ".#" <> show opcode <> @@ -275,7 +276,7 @@ instance IsObjectSide (SomeObject s m) where class (Eq a, Show a) => IsMessage a where opcodeName :: Opcode -> Maybe String - getMessage :: IsInterface i => Object s m i -> Opcode -> Get a + getMessage :: IsInterface i => Object s i -> Opcode -> Get a putMessage :: a -> PutM Opcode instance IsMessage Void where @@ -283,7 +284,7 @@ instance IsMessage Void where getMessage = invalidOpcode putMessage = absurd -invalidOpcode :: IsInterface i => Object s m i -> Opcode -> Get a +invalidOpcode :: IsInterface i => Object s i -> Opcode -> Get a invalidOpcode object opcode = fail $ "Invalid opcode " <> show opcode <> " on " <> objectInterfaceName object <> "@" <> show (objectId object) @@ -292,29 +293,29 @@ showObjectMessage object message = objectInterfaceName object <> "@" <> show (objectId object) <> "." <> show message -data ProtocolState (s :: Side) m = ProtocolState { +data ProtocolState (s :: Side) = ProtocolState { protocolException :: Maybe SomeException, bytesReceived :: !Int64, bytesSent :: !Int64, inboxDecoder :: Decoder RawMessage, outbox :: Maybe Put, - objects :: HashMap GenericObjectId (SomeObject s m), + objects :: HashMap GenericObjectId (SomeObject s), nextId :: Word32 } -data Callback s m i = forall a. IsInterfaceHandler s m i a => Callback a +data Callback s i = forall a. IsInterfaceHandler s i a => Callback a -instance IsInterfaceSide s i => IsInterfaceHandler s m i (Callback s m i) where +instance IsInterfaceSide s i => IsInterfaceHandler s i (Callback s i) where handleMessage (Callback callback) = handleMessage callback -data LowLevelCallback s m i = IsInterfaceSide s i => FnCallback (Object s m i -> Down s i -> ProtocolAction s m ()) +data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolAction s ()) -instance IsInterfaceSide s i => IsInterfaceHandler s m i (LowLevelCallback s m i) where +instance IsInterfaceSide s i => IsInterfaceHandler s i (LowLevelCallback s i) where handleMessage (FnCallback fn) object msg = fn object msg -internalFnCallback :: IsInterfaceSide s i => (Object s m i -> Down s i -> ProtocolAction s m ()) -> Callback s m i +internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolAction s ()) -> Callback s i internalFnCallback = Callback . FnCallback @@ -327,13 +328,13 @@ internalFnCallback = Callback . FnCallback -- trace message. -- -- Uses `traceM` internally. -traceCallback :: (IsInterfaceSide 'Client i, Monad m) => Callback 'Client m i -> Callback 'Client m i +traceCallback :: IsInterfaceSide 'Client i => Callback 'Client i -> Callback 'Client i traceCallback next = internalFnCallback \object message -> do traceM $ "<- " <> showObjectMessage object message handleMessage next object message -- | A `Callback` that ignores all messages. Intended for development purposes, e.g. together with `traceCallback`. -ignoreMessage :: (IsInterfaceSide 'Client i, Monad m) => Callback 'Client m i +ignoreMessage :: IsInterfaceSide 'Client i => Callback 'Client i ignoreMessage = internalFnCallback \_ _ -> pure () -- * Exceptions @@ -360,21 +361,21 @@ data ServerError = ServerError Word32 String -- * Monad plumbing -type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m) +type ProtocolStep s a = ProtocolState s -> STM (Either SomeException a, Maybe BSL.ByteString, ProtocolState s) -- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions. -type ProtocolAction s m a = StateT (ProtocolState s m) m a +type ProtocolAction s a = StateT (ProtocolState s) STM a -protocolStep :: forall s m a. MonadCatch m => ProtocolAction s m a -> ProtocolStep s m a +protocolStep :: forall s a. ProtocolAction s a -> ProtocolStep s a protocolStep action inState = do mapM_ throwM inState.protocolException (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState pure (result, outbox, outState) where - storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m) + storeExceptionIfFailed :: (Either SomeException a, ProtocolState s) -> (Either SomeException a, ProtocolState s) storeExceptionIfFailed (Left ex, st) = (Left ex, setException' ex st) storeExceptionIfFailed x = x - setException' :: Exception e => e -> (ProtocolState s m) -> (ProtocolState s m) + setException' :: Exception e => e -> ProtocolState s -> ProtocolState s setException' ex st = if isJust st.protocolException then st @@ -384,14 +385,14 @@ protocolStep action inState = do -- * Exported functions initialProtocolState - :: forall wl_display wl_registry s m. IsInterfaceSide s wl_display - => Callback s m wl_display - -> (ProtocolState s m, Object s m wl_display) + :: forall wl_display wl_registry s. (IsInterfaceSide s wl_display) + => Callback s wl_display + -> (ProtocolState s, Object s wl_display) initialProtocolState wlDisplayCallback = (initialState, wlDisplay) where - wlDisplay :: Object s m wl_display + wlDisplay :: Object s wl_display wlDisplay = Object 1 wlDisplayCallback - initialState :: ProtocolState s m + initialState :: ProtocolState s initialState = ProtocolState { protocolException = Nothing, bytesReceived = 0, @@ -403,7 +404,7 @@ initialProtocolState wlDisplayCallback = (initialState, wlDisplay) } -- | Feed the protocol newly received data -feedInput :: (IsSide s, MonadCatch m) => ByteString -> ProtocolStep s m () +feedInput :: IsSide s => ByteString -> ProtocolStep s () feedInput bytes = protocolStep do feed receiveMessages @@ -413,29 +414,29 @@ feedInput bytes = protocolStep do inboxDecoder = pushChunk st.inboxDecoder bytes } -setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m () +setException :: Exception e => e -> ProtocolStep s () setException ex = protocolStep do State.modify \st -> st{protocolException = Just (toException ex)} -- | Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object. newObject - :: forall s m i. (IsInterfaceSide s i, MonadCatch m) - => Callback s m i - -> ProtocolStep s m (Object s m i, NewId (InterfaceName i)) + :: forall s i. IsInterfaceSide s i + => Callback s i + -> ProtocolStep s (Object s i, NewId (InterfaceName i)) newObject callback = protocolStep $ newObjectInternal callback newObjectInternal - :: forall s m i. (IsInterfaceSide s i, MonadCatch m) - => Callback s m i - -> ProtocolAction s m (Object s m i, NewId (InterfaceName i)) + :: forall s i. IsInterfaceSide s i + => Callback s i + -> ProtocolAction s (Object s i, NewId (InterfaceName i)) newObjectInternal callback = do - genOId <- allocateObjectId @s @m + genOId <- allocateObjectId @s let oId = NewId @(InterfaceName i) genOId object <- newObjectFromId oId callback pure (object, oId) where - allocateObjectId :: forall s m. (IsSide s, MonadCatch m) => ProtocolAction s m GenericObjectId + allocateObjectId :: forall s. IsSide s => ProtocolAction s GenericObjectId allocateObjectId = do st <- State.get let @@ -447,10 +448,10 @@ newObjectInternal callback = do pure id newObjectFromId - :: forall s m i. (IsInterfaceSide s i, MonadCatch m) + :: forall s i. IsInterfaceSide s i => NewId (InterfaceName i) - -> Callback s m i - -> ProtocolAction s m (Object s m i) + -> Callback s i + -> ProtocolAction s (Object s i) newObjectFromId (NewId oId) callback = do let object = Object oId callback @@ -460,10 +461,10 @@ newObjectFromId (NewId oId) callback = do -- | Sends a message without checking any ids or creating proxy objects objects. -sendMessage :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m () +sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolStep s () sendMessage object message = protocolStep $ sendMessageInternal object message -sendMessageInternal :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolAction s m () +sendMessageInternal :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolAction s () sendMessageInternal object message = do traceM $ "-> " <> showObjectMessage object message sendRawMessage messageWithHeader @@ -483,7 +484,7 @@ sendMessageInternal object message = do msgSizeInteger = 8 + fromIntegral (BSL.length body) -- | Take data that has to be sent (if available) -takeOutbox :: ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m) +takeOutbox :: ProtocolState s -> (Maybe BSL.ByteString, ProtocolState s) takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent + outboxNumBytes}) where maybeOutboxData = if isJust st.protocolException then Nothing else outboxData @@ -491,14 +492,14 @@ takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent outboxNumBytes = maybe 0 BSL.length maybeOutboxData -receiveMessages :: (IsSide s, MonadCatch m) => ProtocolAction s m () +receiveMessages :: IsSide s => ProtocolAction s () receiveMessages = receiveRawMessage >>= \case Nothing -> pure () Just rawMessage -> do handleRawMessage rawMessage receiveMessages -handleRawMessage :: forall s m. MonadCatch m => RawMessage -> ProtocolAction s m () +handleRawMessage :: forall s. RawMessage -> ProtocolAction s () handleRawMessage (oId, opcode, body) = do objects <- State.gets (.objects) case HM.lookup oId objects of @@ -517,16 +518,16 @@ handleRawMessage (oId, opcode, body) = do getMessageAction :: IsInterfaceSide s i - => Object s m i + => Object s i -> Opcode - -> Get (ProtocolAction s m ()) + -> Get (ProtocolAction s ()) getMessageAction object@(Object _ objectHandler) opcode = do message <- getDown object opcode pure $ handleMessage objectHandler object message type RawMessage = (GenericObjectId, Opcode, BSL.ByteString) -receiveRawMessage :: forall s m. MonadCatch m => ProtocolAction s m (Maybe RawMessage) +receiveRawMessage :: forall s. ProtocolAction s (Maybe RawMessage) receiveRawMessage = do st <- State.get (result, newDecoder) <- checkDecoder st.inboxDecoder @@ -535,7 +536,7 @@ receiveRawMessage = do where checkDecoder :: Decoder RawMessage - -> ProtocolAction s m (Maybe RawMessage, Decoder RawMessage) + -> ProtocolAction s (Maybe RawMessage, Decoder RawMessage) checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message) checkDecoder x@(Partial _) = pure (Nothing, x) checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers) @@ -579,7 +580,7 @@ padding :: Integral a => a -> a padding size = ((4 - (size `mod` 4)) `mod` 4) -sendRawMessage :: MonadCatch m => Put -> ProtocolAction s m () +sendRawMessage :: Put -> ProtocolAction s () sendRawMessage x = State.modify \st -> st { outbox = Just (maybe x (<> x) st.outbox) } -- GitLab