diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index 1cca4864f5ee04c8892ae3f90809df1fde7561d6..69da9773eeafc74b0809dfde128d9d0fba3a7b8f 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -55,6 +55,7 @@ module Quasar.Wayland.Protocol.Core ( import Control.Concurrent.STM import Control.Monad (replicateM_) import Control.Monad.Catch +import Control.Monad.Fix (mfix) import Control.Monad.Reader (ReaderT, runReaderT, ask, lift) import Data.Bifunctor qualified as Bifunctor import Data.Binary @@ -70,6 +71,7 @@ import Data.HashMap.Strict qualified as HM import Data.Proxy import Data.String (IsString(..)) import Data.Void (absurd) +import GHC.Conc (unsafeIOToSTM) import GHC.TypeLits import Quasar.Prelude @@ -240,7 +242,7 @@ class IsInterfaceSide s i => IsInterfaceHandler s i a where -- | Data kind data Side = Client | Server -data Object s i = IsInterfaceSide s i => Object GenericObjectId (WireCallback s i) +data Object s i = IsInterfaceSide s i => Object (ProtocolHandle s) GenericObjectId (WireCallback s i) instance IsInterface i => Show (Object s i) where show = showObject @@ -256,7 +258,7 @@ class IsObjectSide a where describeDownMessage :: a -> Opcode -> BSL.ByteString -> String instance forall s i. IsInterface i => IsObject (Object s i) where - objectId (Object oId _) = oId + objectId (Object _ oId _) = oId objectInterfaceName _ = interfaceName @i instance forall s i. IsInterfaceSide s i => IsObjectSide (Object s i) where @@ -377,6 +379,8 @@ newtype ProtocolHandle (s :: Side) = ProtocolHandle { -- | Protocol state handle, containing state for a non-failed protocol (should be kept in a 'ProtocolStateVar') data ProtocolState (s :: Side) = ProtocolState { + protocolKey :: Unique, + protocolHandle :: ProtocolHandle s, bytesReceivedVar :: TVar Int64, bytesSentVar :: TVar Int64, inboxDecoderVar :: TVar (Decoder RawMessage), @@ -387,6 +391,9 @@ data ProtocolState (s :: Side) = ProtocolState { type ProtocolM s a = ReaderT (ProtocolState s) STM a +askProtocol :: ProtocolM s (ProtocolHandle s) +askProtocol = (.protocolHandle) <$> ask + readProtocolVar :: (ProtocolState s -> TVar a) -> ProtocolM s a readProtocolVar fn = do state <- ask @@ -422,9 +429,20 @@ initializeProtocol wlDisplayWireCallback initializationAction = do bytesSentVar <- newTVar 0 inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage outboxVar <- newTVar Nothing - objectsVar <- newTVar $ HM.fromList [(wlDisplayId, (SomeObject wlDisplay))] + protocolKey <- unsafeIOToSTM newUnique + objectsVar <- newTVar $ HM.empty nextIdVar <- newTVar (initialId @s) + + -- Create uninitialized to avoid use of a diverging 'mfix' + stateVar <- newTVar (Left impossibleCodePath) + + let protocol = ProtocolHandle { + stateVar + } + let state = ProtocolState { + protocolHandle = protocol, + protocolKey, bytesReceivedVar, bytesSentVar, inboxDecoderVar, @@ -432,24 +450,23 @@ initializeProtocol wlDisplayWireCallback initializationAction = do objectsVar, nextIdVar } - stateVar <- newTVar (Right state) - let protocol = ProtocolHandle { - stateVar - } + writeTVar stateVar (Right state) + + let wlDisplay = Object protocol wlDisplayId wlDisplayWireCallback + modifyTVar' objectsVar (HM.insert wlDisplayId (SomeObject wlDisplay)) + result <- runReaderT (initializationAction wlDisplay) state pure (result, protocol) where wlDisplayId :: GenericObjectId wlDisplayId = GenericObjectId 1 - wlDisplay :: Object s wl_display - wlDisplay = Object wlDisplayId wlDisplayWireCallback -- | Run a protocol action in 'IO'. If an exception occurs, it is stored as a protocol failure and is then -- re-thrown. -- -- Throws an exception, if the protocol is already in a failed state. runProtocolTransaction :: MonadIO m => ProtocolHandle s -> ProtocolM s a -> m a -runProtocolTransaction (ProtocolHandle stateVar) action = do +runProtocolTransaction (protocol@ProtocolHandle{stateVar}) action = do result <- liftIO $ atomically do readTVar stateVar >>= \case -- Protocol is already in a failed state @@ -526,9 +543,10 @@ newObjectFromId -> WireCallback s i -> ProtocolM s (Object s i) newObjectFromId (NewId oId) callback = do + protocol <- askProtocol let genericObjectId = toGenericObjectId oId - object = Object genericObjectId callback + object = Object protocol genericObjectId callback someObject = SomeObject object modifyProtocolVar (.objectsVar) (HM.insert genericObjectId someObject) pure object @@ -588,7 +606,7 @@ getMessageAction => Object s i -> Opcode -> Get (ProtocolM s ()) -getMessageAction object@(Object _ objectHandler) opcode = do +getMessageAction object@(Object _ _ objectHandler) opcode = do verifyMessage <- getWireDown object opcode pure $ handleMessage objectHandler object =<< verifyMessage