module Quasar.Wayland.Core (
  ObjectId,
  Opcode,
  Fixed,
  IsSide,
  Side(..),
  IsInterface(..),
  IsInterfaceSide(..),
  Object,
  IsObject(..),
  IsMessage(..),
  ProtocolState,
  ClientProtocolState,
  ServerProtocolState,
  ClientCallback,
  ServerCallback,
  Callback(..),
  ProtocolStep,
  initialProtocolState,
  sendMessage,
  feedInput,
  setException,

  -- * Message decoder operations
  WireFormat(..),
  dropRemaining,
) where

import Control.Monad (replicateM_)
import Control.Monad.Catch
import Control.Monad.Catch.Pure
import Control.Monad.Reader (ReaderT, runReaderT)
import Control.Monad.Reader qualified as Reader
import Control.Monad.Writer (WriterT, runWriterT, execWriterT, tell)
import Control.Monad.State (StateT, runStateT, lift)
import Control.Monad.State qualified as State
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.Kind
import Data.Maybe (isJust)
import Data.Typeable (Typeable, cast)
import Data.Void (absurd)
import GHC.TypeLits
import Quasar.Prelude

-- | Numeric id that identifies an object on the wire.
type ObjectId = Word32

-- | Numeric id of a message within an interface.
type Opcode = Word16

-- | Signed 24.8 decimal numbers.
newtype Fixed = Fixed Word32
  deriving Eq

-- | Wrapper around an 'ObjectId', used as the argument type of @new_id@.
newtype NewId = NewId ObjectId


-- | Consume (and discard) all remaining input of the current decoder.
dropRemaining :: Get ()
dropRemaining = void getRemainingLazyByteString


-- | Encoding and decoding of a single argument, selected by the type-level
-- name of the wayland argument type (e.g. @"int"@, @"string"@).
class WireFormat a where
  type Argument a
  putArgument :: Argument a -> PutM ()
  getArgument :: Get (Argument a)

instance WireFormat "int" where
  type Argument "int" = Int32
  putArgument = putInt32host
  getArgument = getInt32host

instance WireFormat "uint" where
  type Argument "uint" = Word32
  putArgument = putWord32host
  getArgument = getWord32host

instance WireFormat "fixed" where
  type Argument "fixed" = Fixed
  putArgument (Fixed repr) = putWord32host repr
  getArgument = Fixed <$> getWord32host

instance WireFormat "string" where
  type Argument "string" = BS.ByteString
  putArgument = putWaylandBlob
  getArgument = getWaylandBlob

instance WireFormat "object" where
  type Argument "object" = ObjectId
  putArgument = putWord32host
  getArgument = getWord32host

instance WireFormat "new_id" where
  type Argument "new_id" = NewId
  putArgument (NewId newId) = putWord32host newId
  getArgument = NewId <$> getWord32host

instance WireFormat "array" where
  type Argument "array" = BS.ByteString
  -- Arrays are length-prefixed and padded like strings, but carry no NUL
  -- terminator, so they must not share the string codec.
  putArgument = putWaylandArray
  getArgument = getWaylandArray

instance WireFormat "fd" where
  type Argument "fd" = Void
  -- 'Void' has no values, so 'absurd' is a total implementation.
  putArgument = absurd
  -- Fail the parse instead of crashing with 'undefined'.
  getArgument = fail "Arguments of type fd cannot be decoded from the byte stream"


-- | A wayland interface
class (
    Binary (Request i),
    Binary (Event i),
    IsMessage (Request i),
    IsMessage (Event i)
  )
  => IsInterface i where
  type Request i
  type Event i
  interfaceName :: String


-- | Selects, per connection side, which message type is sent ('Up') and which
-- is received ('Down').
class IsSide (s :: Side) where
  type Up s i
  type Down s i
  getDown :: forall m i. IsInterface i => Object s m i -> Opcode -> Get (Down s i)

instance IsSide 'Client where
  type Up 'Client i = Request i
  type Down 'Client i = Event i
  getDown :: forall m i. IsInterface i => Object 'Client m i -> Opcode -> Get (Down 'Client i)
  getDown = getMessage @(Down 'Client i)

instance IsSide 'Server where
  type Up 'Server i = Event i
  type Down 'Server i = Request i
  getDown :: forall m i. IsInterface i => Object 'Server m i -> Opcode -> Get (Down 'Server i)
  getDown = getMessage @(Down 'Server i)


-- | Empty class, only required to combine constraints
class (IsSide s, IsInterface i, IsMessage (Up s i), IsMessage (Down s i)) => IsInterfaceSide (s :: Side) i


-- | Data kind
data Side = Client | Server


-- | An object of interface @i@: its id together with the callback that
-- handles messages received for it.
data Object s m i = IsInterfaceSide s i => Object ObjectId (Callback s m i)


class IsObject a where
  objectId :: a -> ObjectId
  objectInterfaceName :: a -> String

class IsObjectSide a where
  describeUpMessage :: a -> Opcode -> BSL.ByteString -> String
  describeDownMessage :: a -> Opcode -> BSL.ByteString -> String

instance forall s m i. IsInterface i => IsObject (Object s m i) where
  objectId (Object oId _) = oId
  objectInterfaceName _ = interfaceName @i

instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where
  describeUpMessage object opcode body =
    objectInterfaceName object <> "@" <> show (objectId object) <>
    "." <> fromMaybe "[invalidOpcode]" (opcodeName @(Up s i) opcode) <>
    " (" <> show (BSL.length body) <> "B)"
  describeDownMessage object opcode body =
    objectInterfaceName object <> "@" <> show (objectId object) <>
    "." <> fromMaybe "[invalidOpcode]" (opcodeName @(Down s i) opcode) <>
    " (" <> show (BSL.length body) <> "B)"


-- | Wayland object quantification wrapper
data SomeObject s m
  = forall i. IsInterfaceSide s i => SomeObject (Object s m i)
  | UnknownObject String ObjectId

instance IsObject (SomeObject s m) where
  objectId (SomeObject object) = objectId object
  objectId (UnknownObject _ oId) = oId
  objectInterfaceName (SomeObject object) = objectInterfaceName object
  objectInterfaceName (UnknownObject interface _) = interface

instance IsObjectSide (SomeObject s m) where
  describeUpMessage (SomeObject object) = describeUpMessage object
  describeUpMessage (UnknownObject interface oId) =
    \opcode body -> interface <> "@" <> show oId <> ".#" <> show opcode <>
      " (" <> show (BSL.length body) <> "B, unknown)"
  describeDownMessage (SomeObject object) = describeDownMessage object
  describeDownMessage (UnknownObject interface oId) =
    \opcode body -> interface <> "@" <> show oId <> ".#" <> show opcode <>
      " (" <> show (BSL.length body) <> "B, unknown)"


-- | A message type (requests or events) of an interface.
class IsMessage a where
  opcodeName :: Opcode -> Maybe String
  showMessage :: IsMessage a => a -> String
  getMessage :: IsInterface i => Object s m i -> Opcode -> Get a
  putMessage :: a -> PutM ()

-- | Message type without any messages: every opcode is invalid.
instance IsMessage Void where
  opcodeName _ = Nothing
  showMessage = absurd
  getMessage = invalidOpcode
  putMessage = absurd

-- | Decoder that always fails, naming the opcode and the target object.
invalidOpcode :: IsInterface i => Object s m i -> Opcode -> Get a
invalidOpcode object opcode =
  fail $ "Invalid opcode " <> show opcode <> " on " <>
    objectInterfaceName object <> "@" <> show (objectId object)


-- TODO remove
data DynamicArgument
  = IntArgument Int32
  | UIntArgument Word32
  -- TODO
  | FixedArgument Void
  | StringArgument String
  | ObjectArgument ObjectId
  | NewIdArgument ObjectId
  | FdArgument ()

-- | Encoded size in bytes. Only the fixed-size (4 byte) argument kinds are
-- supported; every other constructor raises an error.
argumentSize :: DynamicArgument -> Word16
argumentSize (IntArgument _) = 4
argumentSize (UIntArgument _) = 4
argumentSize (ObjectArgument _) = 4
argumentSize (NewIdArgument _) = 4
argumentSize _ = error "argumentSize: unsupported argument type"

-- | Encoder counterpart of 'argumentSize', with the same supported subset.
putDynamicArgument :: DynamicArgument -> Put
putDynamicArgument (IntArgument x) = putInt32host x
putDynamicArgument (UIntArgument x) = putWord32host x
putDynamicArgument (ObjectArgument x) = putWord32host x
putDynamicArgument (NewIdArgument x) = putWord32host x
putDynamicArgument _ = error "putDynamicArgument: unsupported argument type"


type ClientProtocolState m = ProtocolState 'Client m
type ServerProtocolState m = ProtocolState 'Server m

-- | Pure state of one connection.
data ProtocolState (s :: Side) m = ProtocolState
  { -- First failure of a protocol step; once set, 'protocolStep' rethrows it.
    protocolException :: Maybe SomeException,
    -- Incremented by 'feedInput' with the length of every chunk.
    bytesReceived :: Word64,
    -- Initialised to 0; not updated by any code in this module.
    bytesSent :: Word64,
    -- Incremental decoder for the next raw message.
    inboxDecoder :: Decoder RawMessage,
    -- Encoded data that has not been handed out by 'takeOutbox' yet.
    outbox :: Maybe Put,
    objects :: HashMap ObjectId (SomeObject s m)
  }


type ClientCallback m i = Callback 'Client m i
type ServerCallback m i = Callback 'Server m i

-- | Handler for messages received for an object.
data Callback s m i = Callback
  { messageCallback :: Object s m i -> Down s i -> StateT (ProtocolState s m) m ()
  }


-- * Exceptions

data CallbackFailed = CallbackFailed SomeException
  deriving stock Show
  deriving anyclass Exception

data ParserFailed = ParserFailed String String
  deriving stock Show
  deriving anyclass Exception

data ProtocolException = ProtocolException String
  deriving stock Show
  deriving anyclass Exception


-- * Monad plumbing

-- | One step of the protocol: maps the old state to the result of the step
-- (or the exception it failed with), the data to send (if any) and the new
-- state.
type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)

-- | Run a state action as a 'ProtocolStep'.
--
-- Rethrows the exception stored in the incoming state, if there is one.
-- Otherwise the action is run; an exception thrown by it is returned as
-- 'Left' and stored in the resulting state. The outbox is taken from the
-- state via 'takeOutbox'.
protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a
protocolStep action inState = do
  mapM_ throwM inState.protocolException
  (result, afterAction) <- runStateT (try action) inState
  let (outboxBytes, outState) = takeOutbox (recordFailure result afterAction)
  pure (result, outboxBytes, outState)
  where
    recordFailure :: Either SomeException a -> ProtocolState s m -> ProtocolState s m
    recordFailure (Left ex) st = storeException ex st
    recordFailure (Right _) st = st

    -- Keeps an already stored exception; only the first failure is recorded.
    storeException :: Exception e => e -> ProtocolState s m -> ProtocolState s m
    storeException ex st =
      if isJust st.protocolException
        then st
        else st{protocolException = Just (toException ex)}


-- * Exported functions

-- | Build the initial state: registers the given callbacks as object 1
-- (@wl_display@) and object 2 (@wl_registry@) and queues the initial message
-- (see 'sendInitialMessage').
initialProtocolState
  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry)
  => Callback s m wl_display
  -> Callback s m wl_registry
  -> ProtocolState s m
initialProtocolState wlDisplayCallback wlRegistryCallback = sendInitialMessage initialState
  where
    wlDisplay :: Object s m wl_display
    wlDisplay = Object 1 wlDisplayCallback
    wlRegistry :: Object s m wl_registry
    wlRegistry = Object 2 wlRegistryCallback
    initialState :: ProtocolState s m
    initialState = ProtocolState
      { protocolException = Nothing,
        bytesReceived = 0,
        bytesSent = 0,
        inboxDecoder = runGetIncremental getRawMessage,
        outbox = Nothing,
        objects = HM.fromList [(1, SomeObject wlDisplay), (2, SomeObject wlRegistry)]
      }

-- | Feed the protocol newly received data
feedInput :: (IsSide s, MonadCatch m) => ByteString -> ProtocolStep s m ()
feedInput bytes = protocolStep do
  feed
  runCallbacks
  where
    feed = State.modify \st -> st
      { bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
        inboxDecoder = pushChunk st.inboxDecoder bytes
      }

-- | Send a message to an object.
--
-- TODO: not implemented yet; the step currently fails when it reaches the
-- 'undefined' placeholder.
sendMessage :: (IsSide s, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m ()
sendMessage object message = protocolStep do
  undefined message
  runCallbacks

-- | Store an exception in the protocol state.
setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
setException ex = protocolStep do
  State.modify \st -> st{protocolException = Just (toException ex)}


-- * Internals

-- | Take data that has to be sent (if available)
--
-- The outbox is always cleared; no data is returned once an exception has
-- been stored in the state.
takeOutbox :: MonadCatch m => ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m)
takeOutbox st = (outboxBytes, st{outbox = Nothing})
  where
    outboxBytes
      | isJust st.protocolException = Nothing
      | otherwise = runPut <$> st.outbox


-- | Queue opcode 1 on object 1 with the new id 2 as its only argument.
sendInitialMessage :: ProtocolState s m -> ProtocolState s m
sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2]

-- | Handle all messages that can be decoded from the data received so far.
runCallbacks :: (IsSide s, MonadCatch m) => StateT (ProtocolState s m) m ()
runCallbacks =
  receiveRawMessage >>= \case
    Nothing -> pure ()
    Just rawMessage -> do
      handleMessage rawMessage
      runCallbacks

-- | Decode the body of a raw message for its target object and run the
-- resulting action (which invokes the callback of the object).
--
-- Throws 'ProtocolException' when the object id is not registered or refers
-- to an 'UnknownObject', and 'ParserFailed' when the body cannot be decoded
-- or is not consumed completely.
handleMessage :: forall s m. (IsSide s, MonadCatch m) => RawMessage -> StateT (ProtocolState s m) m ()
handleMessage rawMessage@(oId, opcode, body) = do
  st <- State.get
  case HM.lookup oId st.objects of
    Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId

    Just (SomeObject object) ->
      case runGetOrFail (getMessageAction st.objects object rawMessage) body of
        Left (_, _, errMessage) ->
          throwM $ ParserFailed (describeDownMessage object opcode body) errMessage
        Right (leftovers, _, action)
          -- Run the decoded action; it used to be discarded, so callbacks
          -- were never invoked.
          | BSL.null leftovers -> action
          | otherwise ->
              throwM $ ParserFailed (describeDownMessage object opcode body) (show (BSL.length leftovers) <> "B not parsed")

    Just (UnknownObject interface unknownId) ->
      throwM $ ProtocolException $ "Received message for unknown object " <> interface <> "@" <> show unknownId

-- | Decode a message for the given object. The resulting action traces the
-- message and then passes it to the callback of the object.
getMessageAction
  :: (IsSide s, IsInterface i, MonadCatch m)
  => HashMap ObjectId (SomeObject s m)
  -> Object s m i
  -> RawMessage
  -> Get (ProtocolAction s m ())
getMessageAction _objects object@(Object _ (Callback handler)) (_oId, opcode, body) = do
  message <- getDown object opcode
  pure $ do
    traceM $ "Received message " <> describeDownMessage object opcode body
    handler object message

type ProtocolAction s m a = StateT (ProtocolState s m) m a


-- | Target object, opcode and undecoded body of a message.
type RawMessage = (ObjectId, Opcode, BSL.ByteString)

-- | Take the next complete raw message from the inbox decoder, if there is
-- one. Throws 'ParserFailed' when the decoder has failed.
receiveRawMessage :: forall s m. MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage)
receiveRawMessage = do
  st <- State.get
  (result, newDecoder) <- checkDecoder st.inboxDecoder
  State.put st{inboxDecoder = newDecoder}
  pure result
  where
    checkDecoder
      :: MonadCatch m
      => Decoder RawMessage
      -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage)
    checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message)
    checkDecoder x@(Partial _) = pure (Nothing, x)
    -- Start a fresh decoder and hand it the unconsumed input.
    checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)


-- | Decode the message header (object id, then a 32-bit word with the
-- message size in the upper and the opcode in the lower 16 bits) and read the
-- message body. The size includes the 8 byte header.
getRawMessage :: Get RawMessage
getRawMessage = do
  oId <- getWord32host
  sizeAndOpcode <- getWord32host
  let
    totalSize = sizeAndOpcode `shiftR` 16
    opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
  -- A size below the header size would result in a negative body length.
  if totalSize < 8
    then fail $ "Invalid message size " <> show totalSize <> " (smaller than the message header)"
    else do
      body <- getLazyByteString (fromIntegral (totalSize - 8))
      pure (oId, opcode, body)

-- | Decode a length-prefixed, NUL-terminated and padded blob (used for
-- strings). The length includes the terminator, which is not part of the
-- result.
getWaylandBlob :: Get BS.ByteString
getWaylandBlob = do
  size <- getWord32host
  bytes <- getByteString (fromIntegral size)
  case BS.unsnoc bytes of
    Just (string, 0) -> do
      skipPadding
      pure string
    _ -> fail "Invalid string: missing NUL terminator"

-- | Encode a blob in the format read by 'getWaylandBlob': the length
-- (including the NUL terminator), the content, the terminator and padding to
-- a multiple of 4 bytes.
putWaylandBlob :: BS.ByteString -> Put
putWaylandBlob blob = do
  -- The length on the wire includes the terminating NUL byte.
  let size = BS.length blob + 1
  putWord32host (fromIntegral size)
  putByteString blob
  putWord8 0
  putPadding size

-- | Decode a length-prefixed and padded array (no terminator).
getWaylandArray :: Get BS.ByteString
getWaylandArray = do
  size <- getWord32host
  array <- getByteString (fromIntegral size)
  skipPadding
  pure array

-- | Encode an array in the format read by 'getWaylandArray'.
putWaylandArray :: BS.ByteString -> Put
putWaylandArray array = do
  let size = BS.length array
  putWord32host (fromIntegral size)
  putByteString array
  putPadding size

-- | Skip to the next multiple of 4 of the number of bytes read so far.
skipPadding :: Get ()
skipPadding = do
  bytes <- bytesRead
  skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4)

-- | Write the zero bytes required to pad content of the given length to a
-- multiple of 4 bytes.
putPadding :: Int -> Put
putPadding size = replicateM_ ((4 - (size `mod` 4)) `mod` 4) (putWord8 0)


-- | Queue a message built from 'DynamicArgument's.
--
-- Raises an error if the message size does not fit the 16-bit size field or
-- if an argument is not supported by 'argumentSize' / 'putDynamicArgument'.
sendMessageInternal :: ObjectId -> Opcode -> [DynamicArgument] -> ProtocolState s m -> ProtocolState s m
sendMessageInternal oId opcode args = sendRaw do
  putWord32host oId
  putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
  mapM_ putDynamicArgument args
  -- TODO padding
  where
    msgSize :: Word16
    msgSize
      | msgSizeInteger <= fromIntegral (maxBound :: Word16) = fromIntegral msgSizeInteger
      | otherwise = error "sendMessageInternal: message size exceeds the 16-bit size field"
    -- 8 byte header plus the size of all arguments.
    msgSizeInteger :: Integer
    msgSizeInteger = 8 + sum (map (fromIntegral . argumentSize) args)

-- | Append encoded data to the outbox.
sendRaw :: Put -> ProtocolState s m -> ProtocolState s m
sendRaw x oldState = oldState
  { outbox = Just (maybe x (<> x) oldState.outbox)
  }