Skip to content
Snippets Groups Projects
Commit f5aafcaf authored by Jens Nolte's avatar Jens Nolte
Browse files

Add types for bidirectional connection

parent b7df69ef
No related branches found
No related tags found
No related merge requests found
module Quasar.Wayland.Core ( module Quasar.Wayland.Core (
ObjectId, ObjectId,
Opcode, Opcode,
IsInterface(..),
Side(..),
Object,
IsSomeObject,
IsMessage(..),
ProtocolState, ProtocolState,
ClientProtocolState, ClientProtocolState,
initialClientProtocolState, ServerProtocolState,
--ServerProtocolState, ClientCallback,
--initialServerProtocolState, ServerCallback,
Callback(..),
Request, Request,
Event, Event,
ProtocolStep,
initialProtocolState, initialProtocolState,
sendMessage,
feedInput, feedInput,
takeOutbox, setException,
) where ) where
import Control.Monad.State (State) import Control.Monad.Catch
import Control.Monad.State (StateT, runStateT, state, modify)
import Control.Monad.State qualified as State import Control.Monad.State qualified as State
import Data.Binary
import Data.Binary.Get import Data.Binary.Get
import Data.Binary.Put import Data.Binary.Put
import Data.Bits ((.&.), (.|.), shiftL, shiftR) import Data.Bits ((.&.), (.|.), shiftL, shiftR)
...@@ -23,17 +33,62 @@ import Data.ByteString qualified as BS ...@@ -23,17 +33,62 @@ import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL import Data.ByteString.Lazy qualified as BSL
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM import Data.HashMap.Strict qualified as HM
import Data.Maybe (isJust)
import Data.Void (absurd)
import Quasar.Prelude import Quasar.Prelude
type ObjectId = Word32 type ObjectId = Word32
type ObjectType = String
type Opcode = Word16 type Opcode = Word16
data Object = Object {
objectId :: ObjectId, -- | A wayland interface
objectType :: ObjectType class (Binary (TRequest a), Binary (TEvent a)) => IsInterface a where
} type TRequest a
type TEvent a
interfaceName :: String
class IsInterface a => IsObject (s :: Side) a where
type Up s a
type Down s a
data Side = Client | Server
data Object s m a = IsInterface a => Object ObjectId (Callback s m a)
instance IsInterface a => IsObject 'Client a where
type Up 'Client a = TRequest a
type Down 'Client a = TEvent a
instance IsInterface a => IsObject 'Server a where
type Up 'Server a = TEvent a
type Down 'Server a = TRequest a
instance forall s m a. IsInterface a => IsSomeObject (Object s m a) where
objectId (Object oId _) = oId
objectInterfaceName _ = interfaceName @a
mkObject :: forall s m a. IsInterface a => ObjectId -> Callback s m a -> Object s m a
mkObject oId callback = Object @s @m @a oId callback
class IsSomeObject a where
objectId :: a -> ObjectId
objectInterfaceName :: a -> String
-- | Wayland object quantification wrapper
data SomeObject = forall a. IsSomeObject a => SomeObject a
instance IsSomeObject SomeObject where
objectId (SomeObject object) = objectId object
objectInterfaceName (SomeObject object) = objectInterfaceName object
class IsMessage a where
messageName :: a -> String
instance IsMessage Void where
messageName = absurd
data Argument data Argument
...@@ -61,16 +116,16 @@ putArgument (NewIdArgument x) = putWord32host x ...@@ -61,16 +116,16 @@ putArgument (NewIdArgument x) = putWord32host x
putArgument _ = undefined putArgument _ = undefined
type ClientProtocolState = ProtocolState Request Event type ClientProtocolState m = ProtocolState 'Client m
type ServerProtocolState = ProtocolState Event Request type ServerProtocolState m = ProtocolState 'Server m
data ProtocolState up down = ProtocolState { data ProtocolState (s :: Side) m = ProtocolState {
protocolException :: Maybe SomeException,
bytesReceived :: Word64, bytesReceived :: Word64,
bytesSent :: Word64, bytesSent :: Word64,
parser :: Decoder down, inboxDecoder :: Decoder (ObjectId, Opcode, BSL.ByteString),
inboxDecoder :: Decoder down,
outbox :: Maybe Put, outbox :: Maybe Put,
objects :: HashMap ObjectId Object objects :: HashMap ObjectId SomeObject
} }
data Request = Request ObjectId Opcode BSL.ByteString data Request = Request ObjectId Opcode BSL.ByteString
...@@ -78,71 +133,148 @@ data Request = Request ObjectId Opcode BSL.ByteString ...@@ -78,71 +133,148 @@ data Request = Request ObjectId Opcode BSL.ByteString
data Event = Event ObjectId Opcode (Either BSL.ByteString (Word32, BSL.ByteString, Word32)) data Event = Event ObjectId Opcode (Either BSL.ByteString (Word32, BSL.ByteString, Word32))
deriving stock Show deriving stock Show
initialClientProtocolState :: ClientProtocolState
initialClientProtocolState = initialProtocolState decodeEvent type ClientCallback m a = Callback 'Client m a
type ServerCallback m a = Callback 'Server m a
initialProtocolState :: Get down -> ProtocolState up down
initialProtocolState downGet = sendInitialMessage ProtocolState { data Callback s m a = Callback {
bytesReceived = 0, messageCallback :: Object s m a -> Down s a -> StateT (ProtocolState s m) m ()
bytesSent = 0,
parser = runGetIncremental downGet,
inboxDecoder = runGetIncremental downGet,
outbox = Nothing,
objects = HM.singleton 1 (Object 1 "wl_display")
} }
sendInitialMessage :: ProtocolState up down -> ProtocolState up down -- * Exceptions
sendInitialMessage = sendMessage 1 1 [NewIdArgument 2]
data CallbackFailed = CallbackFailed SomeException
deriving stock Show
deriving anyclass Exception
data ParserFailed = ParserFailed String
deriving stock Show
deriving anyclass Exception
-- * Monad plumbing
feedInput :: forall up down. ByteString -> ProtocolState up down -> ([down], ProtocolState up down) type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)
feedInput bytes = State.runState do
State.modify (receive bytes) protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a
go protocolStep action inState = do
mapM_ throwM inState.protocolException
(result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
pure (result, outbox, outState)
where where
go :: State (ProtocolState up down) [down] storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m)
go = State.state takeDownMsg >>= \case storeExceptionIfFailed (Left ex, st) = (Left ex, setException ex st)
Nothing -> pure [] storeExceptionIfFailed x = x
Just msg -> (msg :) <$> go setException :: (MonadCatch m, Exception e) => e -> (ProtocolState s m) -> (ProtocolState s m)
setException ex st =
if isJust st.protocolException
then st
else st{protocolException = Just (toException ex)}
-- * Exported functions
receive :: forall up down. ByteString -> ProtocolState up down -> ProtocolState up down initialProtocolState
receive bytes state = state { :: forall wl_display s m. IsInterface wl_display
bytesReceived = state.bytesReceived + fromIntegral (BS.length bytes), => Callback s m wl_display
inboxDecoder = pushChunk state.inboxDecoder bytes -> ProtocolState s m
} initialProtocolState wlDisplayCallback = sendInitialMessage initialState
where
wlDisplay :: Object s m wl_display
wlDisplay = mkObject 1 wlDisplayCallback
initialState :: ProtocolState s m
initialState = ProtocolState {
protocolException = Nothing,
bytesReceived = 0,
bytesSent = 0,
inboxDecoder = runGetIncremental getRawMessage,
outbox = Nothing,
objects = HM.singleton 1 (SomeObject wlDisplay)
}
-- | Feed the protocol newly received data
feedInput :: MonadCatch m => ByteString -> ProtocolStep s m ()
feedInput bytes = protocolStep do
feed
runCallbacks
where
feed = modify \st -> st {
bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
inboxDecoder = pushChunk st.inboxDecoder bytes
}
sendMessage :: MonadCatch m => Object s m a -> Up s a -> ProtocolStep s m ()
sendMessage object message = protocolStep do
undefined message
runCallbacks
setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
setException ex = protocolStep do
modify \st -> st{protocolException = Just (toException ex)}
-- * Internals
-- | Take data that has to be sent (if available)
takeOutbox :: MonadCatch m => ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m)
takeOutbox st = (runPut <$> st.outbox, st{outbox = Nothing})
sendInitialMessage :: ProtocolState s m -> ProtocolState s m
sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2]
takeDownMsg :: forall up down. ProtocolState up down -> (Maybe down, ProtocolState up down) runCallbacks :: MonadCatch m => StateT (ProtocolState s m) m ()
takeDownMsg state = (result, state{inboxDecoder = newDecoder}) runCallbacks = receiveRawMessage >>= \case
Nothing -> pure ()
Just message -> do
traceM $ show message
runCallbacks
type RawMessage = (ObjectId, Opcode, BSL.ByteString)
getRawMessage :: Get RawMessage
getRawMessage = do
oId <- getWord32host
sizeAndOpcode <- getWord32host
let
size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
body <- getLazyByteString size
pure (oId, opcode, body)
receiveRawMessage :: MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage)
receiveRawMessage = do
st <- State.get
(result, newDecoder) <- checkDecoder st.inboxDecoder
State.put st{inboxDecoder = newDecoder}
pure result
where where
result :: Maybe down checkDecoder :: MonadCatch m => Decoder RawMessage -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage)
newDecoder :: Decoder down checkDecoder d@(Fail _ _ message) = throwM (ParserFailed message)
(result, newDecoder) = checkDecoder state.inboxDecoder checkDecoder x@(Partial _) = pure (Nothing, x)
checkDecoder :: Decoder down -> (Maybe down, Decoder down) checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)
checkDecoder (Fail _ _ _) = undefined
checkDecoder x@(Partial _) = (Nothing, x)
checkDecoder (Done leftovers _ result) = (Just result, pushChunk state.parser leftovers)
decodeEvent :: Get Event decodeEvent :: Get Event
decodeEvent = do decodeEvent = do
objectId <- getWord32host oId <- getWord32host
sizeAndOpcode <- getWord32host sizeAndOpcode <- getWord32host
let let
size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8 size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF) opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
body <- if (objectId == 2 && opcode == 0) body <- if (oId == 2 && opcode == 0)
then Right <$> parseGlobal then Right <$> parseGlobal
else Left <$> getLazyByteString size <* skipPadding else Left <$> getLazyByteString size <* skipPadding
pure $ Event objectId opcode body pure $ Event oId opcode body
where where
parseGlobal :: Get (Word32, BSL.ByteString, Word32) parseGlobal :: Get (Word32, BSL.ByteString, Word32)
parseGlobal = (,,) <$> getWord32host <*> getWaylandString <*> getWord32host parseGlobal = (,,) <$> getWord32host <*> getWaylandString <*> getWord32host
getWaylandString :: Get BSL.ByteString
getWaylandString = do getWaylandString :: Get BSL.ByteString
size <- getWord32host getWaylandString = do
Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size) size <- getWord32host
skipPadding Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size)
pure string skipPadding
pure string
skipPadding :: Get () skipPadding :: Get ()
skipPadding = do skipPadding = do
...@@ -150,9 +282,9 @@ skipPadding = do ...@@ -150,9 +282,9 @@ skipPadding = do
skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4) skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4)
sendMessage :: ObjectId -> Opcode -> [Argument] -> ProtocolState up down -> ProtocolState up down sendMessageInternal :: ObjectId -> Opcode -> [Argument] -> ProtocolState s m -> ProtocolState s m
sendMessage objectId opcode args = sendRaw do sendMessageInternal oId opcode args = sendRaw do
putWord32host objectId putWord32host oId
putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
mapM_ putArgument args mapM_ putArgument args
-- TODO padding -- TODO padding
...@@ -162,11 +294,7 @@ sendMessage objectId opcode args = sendRaw do ...@@ -162,11 +294,7 @@ sendMessage objectId opcode args = sendRaw do
msgSizeInteger :: Integer msgSizeInteger :: Integer
msgSizeInteger = foldr ((+) . (fromIntegral . argumentSize)) 8 args :: Integer msgSizeInteger = foldr ((+) . (fromIntegral . argumentSize)) 8 args :: Integer
sendRaw :: Put -> ProtocolState up down -> ProtocolState up down sendRaw :: Put -> ProtocolState s m -> ProtocolState s m
sendRaw put oldState = oldState { sendRaw put oldState = oldState {
outbox = Just (maybe put (<> put) oldState.outbox) outbox = Just (maybe put (<> put) oldState.outbox)
} }
takeOutbox :: ProtocolState up down -> (Maybe BSL.ByteString, ProtocolState up down)
takeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
akeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment