diff --git a/src/Quasar/Wayland/Core.hs b/src/Quasar/Wayland/Core.hs index 997db9e6cbf3587aab35f2d77483ba3fc36f2e79..0f015ad7fcb98600bc4aa0dbd01447b0059f7949 100644 --- a/src/Quasar/Wayland/Core.hs +++ b/src/Quasar/Wayland/Core.hs @@ -1,20 +1,30 @@ module Quasar.Wayland.Core ( ObjectId, Opcode, + IsInterface(..), + Side(..), + Object, + IsSomeObject, + IsMessage(..), ProtocolState, ClientProtocolState, - initialClientProtocolState, - --ServerProtocolState, - --initialServerProtocolState, + ServerProtocolState, + ClientCallback, + ServerCallback, + Callback(..), Request, Event, + ProtocolStep, initialProtocolState, + sendMessage, feedInput, - takeOutbox, + setException, ) where -import Control.Monad.State (State) +import Control.Monad.Catch +import Control.Monad.State (StateT, runStateT, state, modify) import Control.Monad.State qualified as State +import Data.Binary import Data.Binary.Get import Data.Binary.Put import Data.Bits ((.&.), (.|.), shiftL, shiftR) @@ -23,17 +33,62 @@ import Data.ByteString qualified as BS import Data.ByteString.Lazy qualified as BSL import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as HM +import Data.Maybe (isJust) +import Data.Void (absurd) import Quasar.Prelude type ObjectId = Word32 -type ObjectType = String type Opcode = Word16 -data Object = Object { - objectId :: ObjectId, - objectType :: ObjectType -} + +-- | A wayland interface +class (Binary (TRequest a), Binary (TEvent a)) => IsInterface a where + type TRequest a + type TEvent a + interfaceName :: String + +class IsInterface a => IsObject (s :: Side) a where + type Up s a + type Down s a + +data Side = Client | Server + +data Object s m a = IsInterface a => Object ObjectId (Callback s m a) + +instance IsInterface a => IsObject 'Client a where + type Up 'Client a = TRequest a + type Down 'Client a = TEvent a + +instance IsInterface a => IsObject 'Server a where + type Up 'Server a = TEvent a + type Down 'Server a = TRequest a + +instance forall s m a. IsInterface a => IsSomeObject (Object s m a) where + objectId (Object oId _) = oId + objectInterfaceName _ = interfaceName @a + +mkObject :: forall s m a. IsInterface a => ObjectId -> Callback s m a -> Object s m a +mkObject oId callback = Object @s @m @a oId callback + + +class IsSomeObject a where + objectId :: a -> ObjectId + objectInterfaceName :: a -> String + +-- | Wayland object quantification wrapper +data SomeObject = forall a. IsSomeObject a => SomeObject a + +instance IsSomeObject SomeObject where + objectId (SomeObject object) = objectId object + objectInterfaceName (SomeObject object) = objectInterfaceName object + + +class IsMessage a where + messageName :: a -> String + +instance IsMessage Void where + messageName = absurd data Argument @@ -61,16 +116,16 @@ putArgument (NewIdArgument x) = putWord32host x putArgument _ = undefined -type ClientProtocolState = ProtocolState Request Event -type ServerProtocolState = ProtocolState Event Request +type ClientProtocolState m = ProtocolState 'Client m +type ServerProtocolState m = ProtocolState 'Server m -data ProtocolState up down = ProtocolState { +data ProtocolState (s :: Side) m = ProtocolState { + protocolException :: Maybe SomeException, bytesReceived :: Word64, bytesSent :: Word64, - parser :: Decoder down, - inboxDecoder :: Decoder down, + inboxDecoder :: Decoder (ObjectId, Opcode, BSL.ByteString), outbox :: Maybe Put, - objects :: HashMap ObjectId Object + objects :: HashMap ObjectId SomeObject } data Request = Request ObjectId Opcode BSL.ByteString @@ -78,71 +133,148 @@ data Request = Request ObjectId Opcode BSL.ByteString data Event = Event ObjectId Opcode (Either BSL.ByteString (Word32, BSL.ByteString, Word32)) deriving stock Show -initialClientProtocolState :: ClientProtocolState -initialClientProtocolState = initialProtocolState decodeEvent - -initialProtocolState :: Get down -> ProtocolState up down -initialProtocolState downGet = sendInitialMessage ProtocolState { - bytesReceived = 0, - bytesSent = 0, - parser = runGetIncremental downGet, - inboxDecoder = runGetIncremental downGet, - outbox = Nothing, - objects = HM.singleton 1 (Object 1 "wl_display") + +type ClientCallback m a = Callback 'Client m a +type ServerCallback m a = Callback 'Server m a + +data Callback s m a = Callback { + messageCallback :: Object s m a -> Down s a -> StateT (ProtocolState s m) m () } -sendInitialMessage :: ProtocolState up down -> ProtocolState up down -sendInitialMessage = sendMessage 1 1 [NewIdArgument 2] +-- * Exceptions + +data CallbackFailed = CallbackFailed SomeException + deriving stock Show + deriving anyclass Exception + +data ParserFailed = ParserFailed String + deriving stock Show + deriving anyclass Exception + +-- * Monad plumbing -feedInput :: forall up down. ByteString -> ProtocolState up down -> ([down], ProtocolState up down) -feedInput bytes = State.runState do - State.modify (receive bytes) - go +type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m) + +protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a +protocolStep action inState = do + mapM_ throwM inState.protocolException + (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState + pure (result, outbox, outState) where - go :: State (ProtocolState up down) [down] - go = State.state takeDownMsg >>= \case - Nothing -> pure [] - Just msg -> (msg :) <$> go + storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m) + storeExceptionIfFailed (Left ex, st) = (Left ex, setException ex st) + storeExceptionIfFailed x = x + setException :: (MonadCatch m, Exception e) => e -> (ProtocolState s m) -> (ProtocolState s m) + setException ex st = + if isJust st.protocolException + then st + else st{protocolException = Just (toException ex)} +-- * Exported functions -receive :: forall up down. ByteString -> ProtocolState up down -> ProtocolState up down -receive bytes state = state { - bytesReceived = state.bytesReceived + fromIntegral (BS.length bytes), - inboxDecoder = pushChunk state.inboxDecoder bytes -} +initialProtocolState + :: forall wl_display s m. IsInterface wl_display + => Callback s m wl_display + -> ProtocolState s m +initialProtocolState wlDisplayCallback = sendInitialMessage initialState + where + wlDisplay :: Object s m wl_display + wlDisplay = mkObject 1 wlDisplayCallback + initialState :: ProtocolState s m + initialState = ProtocolState { + protocolException = Nothing, + bytesReceived = 0, + bytesSent = 0, + inboxDecoder = runGetIncremental getRawMessage, + outbox = Nothing, + objects = HM.singleton 1 (SomeObject wlDisplay) + } + +-- | Feed the protocol newly received data +feedInput :: MonadCatch m => ByteString -> ProtocolStep s m () +feedInput bytes = protocolStep do + feed + runCallbacks + where + feed = modify \st -> st { + bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes), + inboxDecoder = pushChunk st.inboxDecoder bytes + } + +sendMessage :: MonadCatch m => Object s m a -> Up s a -> ProtocolStep s m () +sendMessage object message = protocolStep do + undefined message + runCallbacks + +setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m () +setException ex = protocolStep do + modify \st -> st{protocolException = Just (toException ex)} + +-- * Internals + +-- | Take data that has to be sent (if available) +takeOutbox :: MonadCatch m => ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m) +takeOutbox st = (runPut <$> st.outbox, st{outbox = Nothing}) + + +sendInitialMessage :: ProtocolState s m -> ProtocolState s m +sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2] -takeDownMsg :: forall up down. ProtocolState up down -> (Maybe down, ProtocolState up down) -takeDownMsg state = (result, state{inboxDecoder = newDecoder}) +runCallbacks :: MonadCatch m => StateT (ProtocolState s m) m () +runCallbacks = receiveRawMessage >>= \case + Nothing -> pure () + Just message -> do + traceM $ show message + runCallbacks + + +type RawMessage = (ObjectId, Opcode, BSL.ByteString) + +getRawMessage :: Get RawMessage +getRawMessage = do + oId <- getWord32host + sizeAndOpcode <- getWord32host + let + size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8 + opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF) + body <- getLazyByteString size + pure (oId, opcode, body) + +receiveRawMessage :: MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage) +receiveRawMessage = do + st <- State.get + (result, newDecoder) <- checkDecoder st.inboxDecoder + State.put st{inboxDecoder = newDecoder} + + pure result where - result :: Maybe down - newDecoder :: Decoder down - (result, newDecoder) = checkDecoder state.inboxDecoder - checkDecoder :: Decoder down -> (Maybe down, Decoder down) - checkDecoder (Fail _ _ _) = undefined - checkDecoder x@(Partial _) = (Nothing, x) - checkDecoder (Done leftovers _ result) = (Just result, pushChunk state.parser leftovers) + checkDecoder :: MonadCatch m => Decoder RawMessage -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage) + checkDecoder d@(Fail _ _ message) = throwM (ParserFailed message) + checkDecoder x@(Partial _) = pure (Nothing, x) + checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers) decodeEvent :: Get Event decodeEvent = do - objectId <- getWord32host + oId <- getWord32host sizeAndOpcode <- getWord32host let size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8 opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF) - body <- if (objectId == 2 && opcode == 0) + body <- if (oId == 2 && opcode == 0) then Right <$> parseGlobal else Left <$> getLazyByteString size <* skipPadding - pure $ Event objectId opcode body + pure $ Event oId opcode body where parseGlobal :: Get (Word32, BSL.ByteString, Word32) parseGlobal = (,,) <$> getWord32host <*> getWaylandString <*> getWord32host - getWaylandString :: Get BSL.ByteString - getWaylandString = do - size <- getWord32host - Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size) - skipPadding - pure string + +getWaylandString :: Get BSL.ByteString +getWaylandString = do + size <- getWord32host + Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size) + skipPadding + pure string skipPadding :: Get () skipPadding = do @@ -150,9 +282,9 @@ skipPadding = do skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4) -sendMessage :: ObjectId -> Opcode -> [Argument] -> ProtocolState up down -> ProtocolState up down -sendMessage objectId opcode args = sendRaw do - putWord32host objectId +sendMessageInternal :: ObjectId -> Opcode -> [Argument] -> ProtocolState s m -> ProtocolState s m +sendMessageInternal oId opcode args = sendRaw do + putWord32host oId putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode mapM_ putArgument args -- TODO padding @@ -162,11 +294,7 @@ sendMessage objectId opcode args = sendRaw do msgSizeInteger :: Integer msgSizeInteger = foldr ((+) . (fromIntegral . argumentSize)) 8 args :: Integer -sendRaw :: Put -> ProtocolState up down -> ProtocolState up down +sendRaw :: Put -> ProtocolState s m -> ProtocolState s m sendRaw put oldState = oldState { outbox = Just (maybe put (<> put) oldState.outbox) } - -takeOutbox :: ProtocolState up down -> (Maybe BSL.ByteString, ProtocolState up down) -takeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing}) -akeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})