diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs index 246548240fb2da1d01c2a57c491e3c86711300f0..7e8bccbeb7141e8cb992e6b0ae1201c540965d2b 100644 --- a/src/Quasar/Wayland/Client.hs +++ b/src/Quasar/Wayland/Client.hs @@ -30,8 +30,8 @@ newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient newWaylandClient socket = do (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket - (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage) - stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId + (_wlRegistry, newId) <- runProtocolM connection.protocolHandle $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage) + runProtocolM connection.protocolHandle $ sendMessage wlDisplay $ R_wl_display_get_registry newId pure $ WaylandClient connection wlDisplay connectWaylandClient :: MonadResourceManager m => m WaylandClient diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs index 4a401e3a8714bf40088628b7ac11bd978b2f98e8..6013a7ad8a23bf011e3bcf0185d1ea6cee8f9dbf 100644 --- a/src/Quasar/Wayland/Connection.hs +++ b/src/Quasar/Wayland/Connection.hs @@ -1,7 +1,6 @@ module Quasar.Wayland.Connection ( - WaylandConnection, + WaylandConnection(protocolHandle), newWaylandConnection, - stepProtocol, ) where import Control.Concurrent.STM @@ -19,8 +18,7 @@ import Quasar.Wayland.Protocol.Generated data WaylandConnection s = WaylandConnection { - protocolStateVar :: TVar (ProtocolState s), - outboxVar :: TMVar BSL.ByteString, + protocolHandle :: ProtocolHandle s, socket :: Socket, resourceManager :: ResourceManager } @@ -36,20 +34,18 @@ data SocketClosed = SocketClosed deriving anyclass Exception newWaylandConnection - :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m) + :: forall wl_display s m. (IsInterfaceSide s wl_display, MonadResourceManager m) => Callback s wl_display -> Socket -> m (WaylandConnection s, Object s wl_display) newWaylandConnection wlDisplayCallback socket = do - protocolStateVar <- liftIO $ newTVarIO protocolState - outboxVar <- liftIO newEmptyTMVarIO + (wlDisplay, protocolHandle) <- liftIO $ atomically $ initializeProtocol wlDisplayCallback pure resourceManager <- newResourceManager onResourceManager resourceManager do let connection = WaylandConnection { - protocolStateVar, - outboxVar, + protocolHandle, socket, resourceManager } @@ -61,20 +57,6 @@ newWaylandConnection wlDisplayCallback socket = do connectionThread connection $ receiveThread connection pure (connection, wlDisplay) - where - (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback - -stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s a -> m a -stepProtocol connection step = liftIO do - result <- atomically do - oldState <- readTVar connection.protocolStateVar - (result, outBytes, newState) <- step oldState - writeTVar connection.protocolStateVar newState - mapM_ (putTMVar connection.outboxVar) outBytes - pure result - case result of - Left ex -> throwM (ex :: SomeException) - Right result -> pure result connectionThread :: MonadAsync m => WaylandConnection s -> IO () -> m () connectionThread connection work = async_ $ liftIO $ work `catches` [ignoreCancelTask, handleAll] @@ -84,7 +66,7 @@ connectionThread connection work = async_ $ liftIO $ work `catches` [ignoreCance sendThread :: WaylandConnection s -> IO () sendThread connection = forever do - bytes <- atomically $ takeTMVar connection.outboxVar + bytes <- takeOutbox connection.protocolHandle traceIO $ "Sending " <> show (BSL.length bytes) <> " bytes" SocketL.sendAll connection.socket bytes @@ -99,7 +81,7 @@ receiveThread connection = forever do traceIO $ "Received " <> show (BS.length bytes) <> " bytes" - stepProtocol connection $ feedInput bytes + feedInput connection.protocolHandle bytes closeConnection :: WaylandConnection s -> IO (Awaitable ()) closeConnection connection = do diff --git a/src/Quasar/Wayland/Protocol.hs b/src/Quasar/Wayland/Protocol.hs index 7e8d34b3e638eef054976aa93c161d5301596700..e770a3db6b7356107e6d68b373d722fbd3654956 100644 --- a/src/Quasar/Wayland/Protocol.hs +++ b/src/Quasar/Wayland/Protocol.hs @@ -12,24 +12,23 @@ import Quasar.Wayland.Protocol.Core import Quasar.Wayland.Protocol.Generated -createClientStateWithRegistry :: STM (ProtocolState 'Client) +createClientStateWithRegistry :: STM (ProtocolHandle 'Client) createClientStateWithRegistry = do - (wlRegistry, state') <- runStateT go initialState' - pure state' + (wlRegistry, protocolHandle) <- initializeProtocol wlDisplayCallback createRegistry + pure protocolHandle where - (initialState', wlDisplay) = initialProtocolState wlDisplayCallback - - go :: ProtocolAction 'Client (Object 'Client I_wl_registry) - go = do - (wlRegistry, newId) <- newObjectInternal @'Client @I_wl_registry (traceCallback ignoreMessage) - sendMessageInternal wlDisplay $ R_wl_display_get_registry newId + createRegistry :: Object 'Client I_wl_display -> ProtocolM 'Client (Object 'Client I_wl_registry) + createRegistry wlDisplay = do + (wlRegistry, newId) <- newObject @'Client @I_wl_registry (traceCallback ignoreMessage) + sendMessage wlDisplay $ R_wl_display_get_registry newId pure wlRegistry wlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Client I_wl_display wlDisplayCallback = internalFnCallback handler where - handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolAction 'Client () + -- | wl_display is specified to never change, so manually specifying the callback is safe + handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolM 'Client () -- TODO parse oId handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (toString message) handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index 73067ae272bcb797a7604d23647218432de61c3f..6cb097b157c03a127fca1bd85e5964f4cdb4d201 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -15,26 +15,36 @@ module Quasar.Wayland.Protocol.Core ( Object, IsObject, IsMessage(..), - ProtocolState, - ProtocolAction, + ProtocolHandle, + ProtocolM, + + -- * Protocol IO + initializeProtocol, + feedInput, + setException, + takeOutbox, + runProtocolM, + + -- * Low-level protocol interaction + sendMessage, + newObject, + + -- ** Callbacks Callback(..), internalFnCallback, traceCallback, ignoreMessage, - ProtocolStep, - initialProtocolState, - sendMessage, - newObject, - feedInput, - setException, - newObjectInternal, - sendMessageInternal, - - showObjectMessage, - isNewId, + -- * Protocol exceptions + CallbackFailed(..), + ParserFailed(..), + ProtocolException(..), + MaximumIdReached(..), ServerError(..), + -- * TH utilities + isNewId, + -- * Message decoder operations WireFormat(..), dropRemaining, @@ -44,7 +54,7 @@ module Quasar.Wayland.Protocol.Core ( import Control.Concurrent.STM import Control.Monad (replicateM_) import Control.Monad.Catch -import Control.Monad.State (StateT, runStateT) +import Control.Monad.Reader (ReaderT, runReaderT, ask, asks, lift) import Control.Monad.State qualified as State import Data.Binary import Data.Binary.Get @@ -222,7 +232,7 @@ putUp _ = putMessage @(Up s i) class IsInterfaceSide s i => IsInterfaceHandler s i a where - handleMessage :: a -> Object s i -> Down s i -> ProtocolAction s () + handleMessage :: a -> Object s i -> Down s i -> ProtocolM s () -- | Data kind @@ -293,29 +303,18 @@ showObjectMessage object message = objectInterfaceName object <> "@" <> show (objectId object) <> "." <> show message -data ProtocolState (s :: Side) = ProtocolState { - protocolException :: Maybe SomeException, - bytesReceived :: !Int64, - bytesSent :: !Int64, - inboxDecoder :: Decoder RawMessage, - outbox :: Maybe Put, - objects :: HashMap GenericObjectId (SomeObject s), - nextId :: Word32 -} - - data Callback s i = forall a. IsInterfaceHandler s i a => Callback a instance IsInterfaceSide s i => IsInterfaceHandler s i (Callback s i) where handleMessage (Callback callback) = handleMessage callback -data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolAction s ()) +data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolM s ()) instance IsInterfaceSide s i => IsInterfaceHandler s i (LowLevelCallback s i) where handleMessage (FnCallback fn) object msg = fn object msg -internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolAction s ()) -> Callback s i +internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolM s ()) -> Callback s i internalFnCallback = Callback . FnCallback @@ -359,113 +358,158 @@ data ServerError = ServerError Word32 String deriving stock Show deriving anyclass Exception --- * Monad plumbing +-- * Protocol state and monad plumbing -type ProtocolStep s a = ProtocolState s -> STM (Either SomeException a, Maybe BSL.ByteString, ProtocolState s) +-- | Top-level protocol handle (used e.g. to send/receive data) +newtype ProtocolHandle (s :: Side) = ProtocolHandle { + stateVar :: TVar (Either SomeException (ProtocolState s)) +} --- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions. -type ProtocolAction s a = StateT (ProtocolState s) STM a +-- | Protocol state handle, containing state for a non-failed protocol (should be kept in a 'ProtocolStateVar') +data ProtocolState (s :: Side) = ProtocolState { + bytesReceivedVar :: TVar Int64, + bytesSentVar :: TVar Int64, + inboxDecoderVar :: TVar (Decoder RawMessage), + outboxVar :: TVar (Maybe Put), + objectsVar :: TVar (HashMap GenericObjectId (SomeObject s)), + nextIdVar :: TVar Word32 +} -protocolStep :: forall s a. ProtocolAction s a -> ProtocolStep s a -protocolStep action inState = do - mapM_ throwM inState.protocolException - (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState - pure (result, outbox, outState) - where - storeExceptionIfFailed :: (Either SomeException a, ProtocolState s) -> (Either SomeException a, ProtocolState s) - storeExceptionIfFailed (Left ex, st) = (Left ex, setException' ex st) - storeExceptionIfFailed x = x - setException' :: Exception e => e -> ProtocolState s -> ProtocolState s - setException' ex st = - if isJust st.protocolException - then st - else st{protocolException = Just (toException ex)} +type ProtocolM s a = ReaderT (ProtocolState s) STM a + +readProtocolVar :: (ProtocolState s -> TVar a) -> ProtocolM s a +readProtocolVar fn = do + state <- ask + lift $ readTVar (fn state) +writeProtocolVar :: (ProtocolState s -> TVar a) -> a -> ProtocolM s () +writeProtocolVar fn x = do + state <- ask + lift $ writeTVar (fn state) x --- * Exported functions +modifyProtocolVar :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s () +modifyProtocolVar fn x = do + state <- ask + lift $ modifyTVar (fn state) x -initialProtocolState - :: forall wl_display wl_registry s. (IsInterfaceSide s wl_display) +stateProtocolVar :: (ProtocolState s -> TVar a) -> (a -> (r, a)) -> ProtocolM s r +stateProtocolVar fn x = do + state <- ask + lift $ stateTVar (fn state) x + +initializeProtocol + :: forall s wl_display a. (IsInterfaceSide s wl_display) => Callback s wl_display - -> (ProtocolState s, Object s wl_display) -initialProtocolState wlDisplayCallback = (initialState, wlDisplay) + -> (Object s wl_display -> ProtocolM s a) + -> STM (a, ProtocolHandle s) +initializeProtocol wlDisplayCallback initializationAction = do + bytesReceivedVar <- newTVar 0 + bytesSentVar <- newTVar 0 + inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage + outboxVar <- newTVar Nothing + objectsVar <- newTVar $ HM.fromList [(1, (SomeObject wlDisplay))] + nextIdVar <- newTVar (initialId @s) + let state = ProtocolState { + bytesReceivedVar, + bytesSentVar, + inboxDecoderVar, + outboxVar, + objectsVar, + nextIdVar + } + stateVar <- newTVar (Right state) + let handle = ProtocolHandle { + stateVar + } + result <- runReaderT (initializationAction wlDisplay) state + pure (result, handle) where wlDisplay :: Object s wl_display wlDisplay = Object 1 wlDisplayCallback - initialState :: ProtocolState s - initialState = ProtocolState { - protocolException = Nothing, - bytesReceived = 0, - bytesSent = 0, - inboxDecoder = runGetIncremental getRawMessage, - outbox = Nothing, - objects = HM.fromList [(1, (SomeObject wlDisplay))], - nextId = initialId @s - } - --- | Feed the protocol newly received data -feedInput :: IsSide s => ByteString -> ProtocolStep s () -feedInput bytes = protocolStep do - feed + +-- | Entry point to run a protocol action, effectively an 'atomically' with correct error handling. +-- +-- Throws an exception, when the protocol reaches or is in a failed (/error) state. +runProtocolM :: (MonadIO m, MonadThrow m) => ProtocolHandle s -> ProtocolM s a -> m a +runProtocolM (ProtocolHandle stateVar) action = do + result <- liftIO $ atomically do + readTVar stateVar >>= \case + -- Protocol is already in a failed state + Left ex -> throwM ex + Right state -> do + -- Run action, catch exceptions + result <- runReaderT (try action) state + case result of + Left ex -> do + -- Action failed, change protocol state to failed + writeTVar stateVar (Left ex) + pure (Left ex) + Right result -> do + pure (Right result) + -- Transaction is committed, rethrow exception if the action failed + either throwM pure result + + + +-- | Feed the protocol newly received data. +feedInput :: (IsSide s, MonadIO m, MonadThrow m) => ProtocolHandle s -> ByteString -> m () +feedInput protocol bytes = runProtocolM protocol do + -- Exposing MonadIO instead of STM to the outside and using `runProtocolM` here enforces correct exception handling. + modifyProtocolVar (.bytesReceivedVar) (+ fromIntegral (BS.length bytes)) + modifyProtocolVar (.inboxDecoderVar) (`pushChunk` bytes) receiveMessages - where - feed = State.modify \st -> st { - bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes), - inboxDecoder = pushChunk st.inboxDecoder bytes - } -setException :: Exception e => e -> ProtocolStep s () -setException ex = protocolStep do - State.modify \st -> st{protocolException = Just (toException ex)} +-- | Set the protocol to a failed state, e.g. when the socket closed unexpectedly. +setException :: (Exception e, MonadIO m, MonadThrow m) => ProtocolHandle s -> e -> m () +setException protocol ex = runProtocolM protocol $ throwM ex + +-- | Take data that has to be sent. Blocks until data is available. +takeOutbox :: (MonadIO m, MonadThrow m) => ProtocolHandle s -> m (BSL.ByteString) +takeOutbox protocol = runProtocolM protocol do + mOutboxData <- stateProtocolVar (.outboxVar) (\mOutboxData -> (mOutboxData, Nothing)) + outboxData <- maybe (lift retry) pure mOutboxData + let sendData = runPut outboxData + modifyProtocolVar (.bytesSentVar) (+ BSL.length sendData) + pure sendData -- | Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object. newObject :: forall s i. IsInterfaceSide s i => Callback s i - -> ProtocolStep s (Object s i, NewId (InterfaceName i)) -newObject callback = protocolStep $ newObjectInternal callback - -newObjectInternal - :: forall s i. IsInterfaceSide s i - => Callback s i - -> ProtocolAction s (Object s i, NewId (InterfaceName i)) -newObjectInternal callback = do - genOId <- allocateObjectId @s + -> ProtocolM s (Object s i, NewId (InterfaceName i)) +newObject callback = do + genOId <- allocateObjectId let oId = NewId @(InterfaceName i) genOId object <- newObjectFromId oId callback pure (object, oId) where - allocateObjectId :: forall s. IsSide s => ProtocolAction s GenericObjectId + allocateObjectId :: ProtocolM s GenericObjectId allocateObjectId = do - st <- State.get - let - id = st.nextId - nextId' = id + 1 + id' <- readProtocolVar (.nextIdVar) + let nextId' = id' + 1 when (nextId' == maximumId @s) $ throwM MaximumIdReached - State.put $ st {nextId = nextId'} - pure id + + writeProtocolVar (.nextIdVar) nextId' + pure id' newObjectFromId :: forall s i. IsInterfaceSide s i => NewId (InterfaceName i) -> Callback s i - -> ProtocolAction s (Object s i) + -> ProtocolM s (Object s i) newObjectFromId (NewId oId) callback = do let object = Object oId callback someObject = SomeObject object - State.modify \st -> st { objects = HM.insert oId someObject st.objects} + modifyProtocolVar (.objectsVar) (HM.insert oId someObject) pure object --- | Sends a message without checking any ids or creating proxy objects objects. -sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolStep s () -sendMessage object message = protocolStep $ sendMessageInternal object message - -sendMessageInternal :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolAction s () -sendMessageInternal object message = do +-- | Sends a message without checking any ids or creating proxy objects objects. (TODO) +sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s () +sendMessage object message = do traceM $ "-> " <> showObjectMessage object message sendRawMessage messageWithHeader where @@ -483,25 +527,17 @@ sendMessageInternal object message = do msgSizeInteger :: Integer msgSizeInteger = 8 + fromIntegral (BSL.length body) --- | Take data that has to be sent (if available) -takeOutbox :: ProtocolState s -> (Maybe BSL.ByteString, ProtocolState s) -takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent + outboxNumBytes}) - where - maybeOutboxData = if isJust st.protocolException then Nothing else outboxData - outboxData = runPut <$> st.outbox - outboxNumBytes = maybe 0 BSL.length maybeOutboxData - -receiveMessages :: IsSide s => ProtocolAction s () +receiveMessages :: IsSide s => ProtocolM s () receiveMessages = receiveRawMessage >>= \case Nothing -> pure () Just rawMessage -> do handleRawMessage rawMessage receiveMessages -handleRawMessage :: forall s. RawMessage -> ProtocolAction s () +handleRawMessage :: forall s. RawMessage -> ProtocolM s () handleRawMessage (oId, opcode, body) = do - objects <- State.gets (.objects) + objects <- readProtocolVar (.objectsVar) case HM.lookup oId objects of Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId @@ -520,23 +556,22 @@ getMessageAction :: IsInterfaceSide s i => Object s i -> Opcode - -> Get (ProtocolAction s ()) + -> Get (ProtocolM s ()) getMessageAction object@(Object _ objectHandler) opcode = do message <- getDown object opcode pure $ handleMessage objectHandler object message type RawMessage = (GenericObjectId, Opcode, BSL.ByteString) -receiveRawMessage :: forall s. ProtocolAction s (Maybe RawMessage) +receiveRawMessage :: forall s. ProtocolM s (Maybe RawMessage) receiveRawMessage = do - st <- State.get - (result, newDecoder) <- checkDecoder st.inboxDecoder - State.put st{inboxDecoder = newDecoder} + (result, nextDecoder) <- checkDecoder =<< readProtocolVar (.inboxDecoderVar) + writeProtocolVar (.inboxDecoderVar) nextDecoder pure result where checkDecoder :: Decoder RawMessage - -> ProtocolAction s (Maybe RawMessage, Decoder RawMessage) + -> ProtocolM s (Maybe RawMessage, Decoder RawMessage) checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message) checkDecoder x@(Partial _) = pure (Nothing, x) checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers) @@ -580,7 +615,5 @@ padding :: Integral a => a -> a padding size = ((4 - (size `mod` 4)) `mod` 4) -sendRawMessage :: Put -> ProtocolAction s () -sendRawMessage x = State.modify \st -> st { - outbox = Just (maybe x (<> x) st.outbox) -} +sendRawMessage :: Put -> ProtocolM s () +sendRawMessage x = modifyProtocolVar (.outboxVar) (Just . maybe x (<> x))