Skip to content
Snippets Groups Projects
Commit f5aafcaf authored by Jens Nolte's avatar Jens Nolte
Browse files

Add types for bidirectional connection

parent b7df69ef
No related branches found
No related tags found
No related merge requests found
module Quasar.Wayland.Core (
ObjectId,
Opcode,
IsInterface(..),
Side(..),
Object,
IsSomeObject,
IsMessage(..),
ProtocolState,
ClientProtocolState,
initialClientProtocolState,
--ServerProtocolState,
--initialServerProtocolState,
ServerProtocolState,
ClientCallback,
ServerCallback,
Callback(..),
Request,
Event,
ProtocolStep,
initialProtocolState,
sendMessage,
feedInput,
takeOutbox,
setException,
) where
import Control.Monad.State (State)
import Control.Monad.Catch
import Control.Monad.State (StateT, runStateT, state, modify)
import Control.Monad.State qualified as State
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
......@@ -23,17 +33,62 @@ import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.Maybe (isJust)
import Data.Void (absurd)
import Quasar.Prelude
type ObjectId = Word32
type ObjectType = String
type Opcode = Word16
data Object = Object {
objectId :: ObjectId,
objectType :: ObjectType
}
-- | A wayland interface
class (Binary (TRequest a), Binary (TEvent a)) => IsInterface a where
type TRequest a
type TEvent a
interfaceName :: String
class IsInterface a => IsObject (s :: Side) a where
type Up s a
type Down s a
data Side = Client | Server
data Object s m a = IsInterface a => Object ObjectId (Callback s m a)
instance IsInterface a => IsObject 'Client a where
type Up 'Client a = TRequest a
type Down 'Client a = TEvent a
instance IsInterface a => IsObject 'Server a where
type Up 'Server a = TEvent a
type Down 'Server a = TRequest a
instance forall s m a. IsInterface a => IsSomeObject (Object s m a) where
objectId (Object oId _) = oId
objectInterfaceName _ = interfaceName @a
mkObject :: forall s m a. IsInterface a => ObjectId -> Callback s m a -> Object s m a
mkObject oId callback = Object @s @m @a oId callback
class IsSomeObject a where
objectId :: a -> ObjectId
objectInterfaceName :: a -> String
-- | Wayland object quantification wrapper
data SomeObject = forall a. IsSomeObject a => SomeObject a
instance IsSomeObject SomeObject where
objectId (SomeObject object) = objectId object
objectInterfaceName (SomeObject object) = objectInterfaceName object
class IsMessage a where
messageName :: a -> String
instance IsMessage Void where
messageName = absurd
data Argument
......@@ -61,16 +116,16 @@ putArgument (NewIdArgument x) = putWord32host x
putArgument _ = undefined
type ClientProtocolState = ProtocolState Request Event
type ServerProtocolState = ProtocolState Event Request
type ClientProtocolState m = ProtocolState 'Client m
type ServerProtocolState m = ProtocolState 'Server m
data ProtocolState up down = ProtocolState {
data ProtocolState (s :: Side) m = ProtocolState {
protocolException :: Maybe SomeException,
bytesReceived :: Word64,
bytesSent :: Word64,
parser :: Decoder down,
inboxDecoder :: Decoder down,
inboxDecoder :: Decoder (ObjectId, Opcode, BSL.ByteString),
outbox :: Maybe Put,
objects :: HashMap ObjectId Object
objects :: HashMap ObjectId SomeObject
}
data Request = Request ObjectId Opcode BSL.ByteString
......@@ -78,71 +133,148 @@ data Request = Request ObjectId Opcode BSL.ByteString
data Event = Event ObjectId Opcode (Either BSL.ByteString (Word32, BSL.ByteString, Word32))
deriving stock Show
initialClientProtocolState :: ClientProtocolState
initialClientProtocolState = initialProtocolState decodeEvent
initialProtocolState :: Get down -> ProtocolState up down
initialProtocolState downGet = sendInitialMessage ProtocolState {
bytesReceived = 0,
bytesSent = 0,
parser = runGetIncremental downGet,
inboxDecoder = runGetIncremental downGet,
outbox = Nothing,
objects = HM.singleton 1 (Object 1 "wl_display")
type ClientCallback m a = Callback 'Client m a
type ServerCallback m a = Callback 'Server m a
data Callback s m a = Callback {
messageCallback :: Object s m a -> Down s a -> StateT (ProtocolState s m) m ()
}
sendInitialMessage :: ProtocolState up down -> ProtocolState up down
sendInitialMessage = sendMessage 1 1 [NewIdArgument 2]
-- * Exceptions
data CallbackFailed = CallbackFailed SomeException
deriving stock Show
deriving anyclass Exception
data ParserFailed = ParserFailed String
deriving stock Show
deriving anyclass Exception
-- * Monad plumbing
feedInput :: forall up down. ByteString -> ProtocolState up down -> ([down], ProtocolState up down)
feedInput bytes = State.runState do
State.modify (receive bytes)
go
type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)
protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a
protocolStep action inState = do
mapM_ throwM inState.protocolException
(result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
pure (result, outbox, outState)
where
go :: State (ProtocolState up down) [down]
go = State.state takeDownMsg >>= \case
Nothing -> pure []
Just msg -> (msg :) <$> go
storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m)
storeExceptionIfFailed (Left ex, st) = (Left ex, setException ex st)
storeExceptionIfFailed x = x
setException :: (MonadCatch m, Exception e) => e -> (ProtocolState s m) -> (ProtocolState s m)
setException ex st =
if isJust st.protocolException
then st
else st{protocolException = Just (toException ex)}
-- * Exported functions
receive :: forall up down. ByteString -> ProtocolState up down -> ProtocolState up down
receive bytes state = state {
bytesReceived = state.bytesReceived + fromIntegral (BS.length bytes),
inboxDecoder = pushChunk state.inboxDecoder bytes
}
initialProtocolState
:: forall wl_display s m. IsInterface wl_display
=> Callback s m wl_display
-> ProtocolState s m
initialProtocolState wlDisplayCallback = sendInitialMessage initialState
where
wlDisplay :: Object s m wl_display
wlDisplay = mkObject 1 wlDisplayCallback
initialState :: ProtocolState s m
initialState = ProtocolState {
protocolException = Nothing,
bytesReceived = 0,
bytesSent = 0,
inboxDecoder = runGetIncremental getRawMessage,
outbox = Nothing,
objects = HM.singleton 1 (SomeObject wlDisplay)
}
-- | Feed the protocol newly received data
feedInput :: MonadCatch m => ByteString -> ProtocolStep s m ()
feedInput bytes = protocolStep do
feed
runCallbacks
where
feed = modify \st -> st {
bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
inboxDecoder = pushChunk st.inboxDecoder bytes
}
sendMessage :: MonadCatch m => Object s m a -> Up s a -> ProtocolStep s m ()
sendMessage object message = protocolStep do
undefined message
runCallbacks
setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
setException ex = protocolStep do
modify \st -> st{protocolException = Just (toException ex)}
-- * Internals
-- | Take data that has to be sent (if available)
takeOutbox :: MonadCatch m => ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m)
takeOutbox st = (runPut <$> st.outbox, st{outbox = Nothing})
sendInitialMessage :: ProtocolState s m -> ProtocolState s m
sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2]
takeDownMsg :: forall up down. ProtocolState up down -> (Maybe down, ProtocolState up down)
takeDownMsg state = (result, state{inboxDecoder = newDecoder})
runCallbacks :: MonadCatch m => StateT (ProtocolState s m) m ()
runCallbacks = receiveRawMessage >>= \case
Nothing -> pure ()
Just message -> do
traceM $ show message
runCallbacks
type RawMessage = (ObjectId, Opcode, BSL.ByteString)
getRawMessage :: Get RawMessage
getRawMessage = do
oId <- getWord32host
sizeAndOpcode <- getWord32host
let
size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
body <- getLazyByteString size
pure (oId, opcode, body)
receiveRawMessage :: MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage)
receiveRawMessage = do
st <- State.get
(result, newDecoder) <- checkDecoder st.inboxDecoder
State.put st{inboxDecoder = newDecoder}
pure result
where
result :: Maybe down
newDecoder :: Decoder down
(result, newDecoder) = checkDecoder state.inboxDecoder
checkDecoder :: Decoder down -> (Maybe down, Decoder down)
checkDecoder (Fail _ _ _) = undefined
checkDecoder x@(Partial _) = (Nothing, x)
checkDecoder (Done leftovers _ result) = (Just result, pushChunk state.parser leftovers)
checkDecoder :: MonadCatch m => Decoder RawMessage -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage)
checkDecoder d@(Fail _ _ message) = throwM (ParserFailed message)
checkDecoder x@(Partial _) = pure (Nothing, x)
checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)
decodeEvent :: Get Event
decodeEvent = do
objectId <- getWord32host
oId <- getWord32host
sizeAndOpcode <- getWord32host
let
size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
body <- if (objectId == 2 && opcode == 0)
body <- if (oId == 2 && opcode == 0)
then Right <$> parseGlobal
else Left <$> getLazyByteString size <* skipPadding
pure $ Event objectId opcode body
pure $ Event oId opcode body
where
parseGlobal :: Get (Word32, BSL.ByteString, Word32)
parseGlobal = (,,) <$> getWord32host <*> getWaylandString <*> getWord32host
getWaylandString :: Get BSL.ByteString
getWaylandString = do
size <- getWord32host
Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size)
skipPadding
pure string
getWaylandString :: Get BSL.ByteString
getWaylandString = do
size <- getWord32host
Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size)
skipPadding
pure string
skipPadding :: Get ()
skipPadding = do
......@@ -150,9 +282,9 @@ skipPadding = do
skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4)
sendMessage :: ObjectId -> Opcode -> [Argument] -> ProtocolState up down -> ProtocolState up down
sendMessage objectId opcode args = sendRaw do
putWord32host objectId
sendMessageInternal :: ObjectId -> Opcode -> [Argument] -> ProtocolState s m -> ProtocolState s m
sendMessageInternal oId opcode args = sendRaw do
putWord32host oId
putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
mapM_ putArgument args
-- TODO padding
......@@ -162,11 +294,7 @@ sendMessage objectId opcode args = sendRaw do
msgSizeInteger :: Integer
msgSizeInteger = foldr ((+) . (fromIntegral . argumentSize)) 8 args :: Integer
sendRaw :: Put -> ProtocolState up down -> ProtocolState up down
sendRaw :: Put -> ProtocolState s m -> ProtocolState s m
sendRaw put oldState = oldState {
outbox = Just (maybe put (<> put) oldState.outbox)
}
takeOutbox :: ProtocolState up down -> (Maybe BSL.ByteString, ProtocolState up down)
takeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
akeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment