Skip to content
Snippets Groups Projects
Core.hs 14.23 KiB
module Quasar.Wayland.Core (
  ObjectId,
  Opcode,
  Fixed,
  IsSide,
  Side(..),
  IsInterface(..),
  IsInterfaceSide(..),
  Object,
  IsObject(..),
  IsObject,
  IsMessage(..),
  ProtocolState,
  ClientProtocolState,
  ServerProtocolState,
  ClientCallback,
  ServerCallback,
  Callback(..),
  ProtocolStep,
  initialProtocolState,
  sendMessage,
  feedInput,
  setException,

  -- Message decoder operations
  WireFormat(..),
  dropRemaining,
) where

import Control.Monad (replicateM_)
import Control.Monad.Catch
import Control.Monad.Catch.Pure
import Control.Monad.Reader (ReaderT, runReaderT)
import Control.Monad.Reader qualified as Reader
import Control.Monad.Writer (WriterT, runWriterT, execWriterT, tell)
import Control.Monad.State (StateT, runStateT, lift)
import Control.Monad.State qualified as State
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.Kind
import Data.Maybe (isJust)
import Data.Typeable (Typeable, cast)
import Data.Void (absurd)
import GHC.TypeLits
import Quasar.Prelude


-- | Id of a wayland object, carried as an unsigned 32-bit word on the wire.
type ObjectId = Word32
-- | Message opcode; occupies the low 16 bits of the second header word of a
-- message (see 'getRawMessage').
type Opcode = Word16

-- | Signed 24.8 fixed-point number.
--
-- Only the raw 32-bit wire representation is stored; this module does not
-- convert it to or from a numeric type.
newtype Fixed = Fixed Word32
  deriving Eq

-- | Id announced for a newly created object (the @new_id@ argument type).
newtype NewId = NewId ObjectId


-- | Consume and discard all input that has not been read yet.
dropRemaining :: Get ()
dropRemaining = () <$ getRemainingLazyByteString


-- | Wire encoding of one wayland argument type. Instances are indexed by a
-- type-level string naming the argument type (e.g. @"uint"@, @"new_id"@).
--
-- The methods mention @a@ only through the 'Argument' type family, so callers
-- have to select the instance with a type application.
class WireFormat a where
  -- | Haskell representation of an argument of this type.
  type Argument a
  -- | Serialize one argument.
  putArgument :: Argument a -> PutM ()
  -- | Deserialize one argument.
  getArgument :: Get (Argument a)

-- | @int@: signed 32-bit integer in host byte order.
instance WireFormat "int" where
  type Argument "int" = Int32
  getArgument = getInt32host
  putArgument = putInt32host

-- | @uint@: unsigned 32-bit integer in host byte order.
instance WireFormat "uint" where
  type Argument "uint" = Word32
  getArgument = getWord32host
  putArgument = putWord32host

-- | @fixed@: transferred as its raw 32-bit representation.
instance WireFormat "fixed" where
  type Argument "fixed" = Fixed
  getArgument = fmap Fixed getWord32host
  putArgument = \(Fixed raw) -> putWord32host raw

-- | @string@: encoded with 'putWaylandBlob' / decoded with 'getWaylandBlob'.
instance WireFormat "string" where
  type Argument "string" = BS.ByteString
  getArgument = getWaylandBlob
  putArgument = putWaylandBlob

-- | @object@: id of an existing object.
instance WireFormat "object" where
  type Argument "object" = ObjectId
  getArgument = getWord32host
  putArgument = putWord32host

-- | @new_id@: id of a newly created object.
instance WireFormat "new_id" where
  type Argument "new_id" = NewId
  getArgument = fmap NewId getWord32host
  putArgument = \(NewId announcedId) -> putWord32host announcedId

-- | @array@: a 32-bit byte count, the contents verbatim, then zero padding to
-- a 32-bit boundary.
--
-- Unlike @string@, an array carries no NUL terminator, so the string blob
-- helpers must not be used here: they append a NUL when encoding and strip
-- (and require) a trailing NUL when decoding, which would corrupt arrays.
instance WireFormat "array" where
  type Argument "array" = BS.ByteString
  putArgument bytes = do
    let size = BS.length bytes
    putWord32host (fromIntegral size)
    putByteString bytes
    replicateM_ ((4 - size `mod` 4) `mod` 4) (putWord8 0)
  getArgument = do
    size <- fromIntegral <$> getWord32host
    bytes <- getByteString size
    -- Arguments are 32-bit aligned, so padding depends only on the size.
    skip ((4 - size `mod` 4) `mod` 4)
    pure bytes

-- | @fd@: passing file descriptors is not implemented yet.
--
-- 'Void' makes an fd argument unconstructible, so encoding is vacuous
-- ('absurd'). Decoding fails as a regular parse error instead of crashing
-- with 'undefined'.
-- NOTE(review): in wayland, fds travel out-of-band rather than in the message
-- body — confirm when implementing.
instance WireFormat "fd" where
  type Argument "fd" = Void
  putArgument = absurd
  getArgument = fail "Receiving fd arguments is not supported"


-- | A wayland interface, identified at the type level by @i@.
--
-- Both message types of the interface must be serializable ('Binary') and
-- support the 'IsMessage' operations.
class
  (
    Binary (Request i),
    Binary (Event i),
    IsMessage (Request i),
    IsMessage (Event i)
  )
  => IsInterface i where
  -- | Messages sent by the client (the 'Up' messages of @'Client@).
  type Request i
  -- | Messages sent by the server (the 'Up' messages of @'Server@).
  type Event i
  -- | Name of the interface, used when describing objects in logs and errors.
  -- The type does not mention @i@, so it is selected with a type application.
  interfaceName :: String

-- | Per-side view of an interface: which message type side @s@ sends ('Up')
-- and which it receives ('Down').
class IsSide (s :: Side) where
  -- | Messages sent by side @s@ (see 'sendMessage').
  type Up s i
  -- | Messages received by side @s@ (see 'handleMessage').
  type Down s i
  -- | Decode a received message with the given opcode for @object@.
  getDown :: forall m i. IsInterface i => Object s m i -> Opcode -> Get (Down s i)

-- | A client sends requests and receives events.
instance IsSide 'Client where
  type Up 'Client i = Request i
  type Down 'Client i = Event i
  -- The instance signature brings @i@ into scope for the type application below.
  getDown :: forall m i. IsInterface i => Object 'Client m i -> Opcode -> Get (Down 'Client i)
  getDown = getMessage @(Down 'Client i)

-- | A server sends events and receives requests.
instance IsSide 'Server where
  type Up 'Server i = Event i
  type Down 'Server i = Request i
  -- The instance signature brings @i@ into scope for the type application below.
  getDown :: forall m i. IsInterface i => Object 'Server m i -> Opcode -> Get (Down 'Server i)
  getDown = getMessage @(Down 'Server i)


-- | Empty class, only required to combine constraints.
--
-- Bundles everything needed to handle interface @i@ on side @s@. No instances
-- are defined in this module.
-- NOTE(review): presumably instances are provided alongside the interface
-- definitions — confirm.
class (IsSide s, IsInterface i, IsMessage (Up s i), IsMessage (Down s i)) => IsInterfaceSide (s :: Side) i


-- | The two ends of a wayland connection.
--
-- Used as a data kind: the promoted @'Client@ and @'Server@ index the
-- protocol types of this module.
data Side = Client | Server

-- | An object of interface @i@ on side @s@: its id and the callback that
-- receives its incoming ('Down') messages.
--
-- The 'IsInterfaceSide' context is stored in the constructor, so pattern
-- matching on 'Object' brings it back into scope.
data Object s m i = IsInterfaceSide s i => Object ObjectId (Callback s m i)

-- | Identification of an object.
class IsObject a where
  -- | Wire id of the object.
  objectId :: a -> ObjectId
  -- | Interface name of the object, used in log and error messages.
  objectInterfaceName :: a -> String

-- | Human-readable description of a message for an object, given the opcode
-- and the raw message body.
class IsObjectSide a where
  -- | Describe an outgoing ('Up') message.
  describeUpMessage :: a -> Opcode -> BSL.ByteString -> String
  -- | Describe an incoming ('Down') message.
  describeDownMessage :: a -> Opcode -> BSL.ByteString -> String

-- | The id is stored in the constructor; the interface name comes from the
-- type index @i@.
instance forall s m i. IsInterface i => IsObject (Object s m i) where
  objectInterfaceName _ = interfaceName @i
  objectId object = case object of
    Object oId _ -> oId

-- | Render a message of a known object as
-- @interface\@id.messageName (lengthB)@, falling back to "[invalidOpcode]"
-- when the opcode has no name.
describeObjectMessage :: IsObject a => a -> Maybe String -> BSL.ByteString -> String
describeObjectMessage object messageName body = mconcat
  [ objectInterfaceName object, "@", show (objectId object)
  , ".", fromMaybe "[invalidOpcode]" messageName
  , " (", show (BSL.length body), "B)"
  ]

-- | Message names are looked up with 'opcodeName' on the 'Up' / 'Down'
-- message type of the interface.
instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where
  describeUpMessage object opcode = describeObjectMessage object (opcodeName @(Up s i) opcode)
  describeDownMessage object opcode = describeObjectMessage object (opcodeName @(Down s i) opcode)

-- | Wayland object quantification wrapper
--
-- 'SomeObject' hides the interface type of a known object, so objects of
-- different interfaces can be stored together (see the @objects@ map of
-- 'ProtocolState'). 'UnknownObject' only records an interface name and an id;
-- 'handleMessage' rejects messages addressed to it.
data SomeObject s m
  = forall i. IsInterfaceSide s i => SomeObject (Object s m i)
  | UnknownObject String ObjectId

-- | Known objects delegate to the wrapped 'Object'; unknown objects answer
-- from their own fields.
instance IsObject (SomeObject s m) where
  objectId = \case
    SomeObject object -> objectId object
    UnknownObject _ oId -> oId
  objectInterfaceName = \case
    SomeObject object -> objectInterfaceName object
    UnknownObject interface _ -> interface

-- | Known objects delegate to the wrapped 'Object'; unknown objects can only
-- show the numeric opcode.
instance IsObjectSide (SomeObject s m) where
  describeUpMessage = \case
    SomeObject object -> describeUpMessage object
    UnknownObject interface oId -> describeUnknownObjectMessage interface oId
  describeDownMessage = \case
    SomeObject object -> describeDownMessage object
    UnknownObject interface oId -> describeUnknownObjectMessage interface oId

-- | Description of a message addressed to an 'UnknownObject', for which no
-- opcode names are available.
describeUnknownObjectMessage :: String -> ObjectId -> Opcode -> BSL.ByteString -> String
describeUnknownObjectMessage interface oId opcode body = mconcat
  [ interface, "@", show oId, ".#", show opcode
  , " (", show (BSL.length body), "B, unknown)"
  ]


-- | Operations on a message type: the requests or the events of an
-- interface (see the superclass constraints of 'IsInterface').
class IsMessage a where
  -- | Name of the message with the given opcode, or 'Nothing' when the opcode
  -- is not valid for @a@. The type does not mention @a@, so it has to be
  -- selected with a type application.
  opcodeName :: Opcode -> Maybe String
  -- | Render a message as a string.
  showMessage :: a -> String
  -- | Decode a message with the given opcode addressed to @object@.
  getMessage :: IsInterface i => Object s m i -> Opcode -> Get a
  -- | Serialize a message.
  putMessage :: a -> PutM ()

-- | No opcode is valid for 'Void': decoding always fails via 'invalidOpcode'.
-- Because 'Void' has no values, showing and encoding are vacuous.
instance IsMessage Void where
  putMessage = absurd
  showMessage = absurd
  opcodeName = const Nothing
  getMessage = invalidOpcode
-- | Fail decoding because @opcode@ is not valid for the interface of @object@.
invalidOpcode :: IsInterface i => Object s m i -> Opcode -> Get a
invalidOpcode object opcode = fail (concat ["Invalid opcode ", show opcode, " on ", objectInterfaceName object, "@", show (objectId object)])



-- TODO remove
-- | Untyped message argument for messages built by hand with
-- 'sendMessageInternal' (currently only used by 'sendInitialMessage').
data DynamicArgument
  = IntArgument Int32
  | UIntArgument Word32
  -- TODO: 'Void' placeholder, fixed arguments cannot be constructed yet.
  | FixedArgument Void
  | StringArgument String
  | ObjectArgument ObjectId
  | NewIdArgument ObjectId
  | FdArgument ()

-- | Number of bytes 'putDynamicArgument' writes for an argument.
--
-- Only the 32-bit argument kinds are supported. String and fd arguments
-- raise a descriptive error because their encoding is not implemented.
-- 'FixedArgument' cannot be constructed (it holds 'Void').
argumentSize :: DynamicArgument -> Word16
argumentSize (IntArgument _) = 4
argumentSize (UIntArgument _) = 4
argumentSize (ObjectArgument _) = 4
argumentSize (NewIdArgument _) = 4
argumentSize (FixedArgument impossible) = absurd impossible
argumentSize (StringArgument _) = error "argumentSize: string arguments are not supported"
argumentSize (FdArgument _) = error "argumentSize: fd arguments are not supported"

-- | Serialize an untyped argument in host byte order.
--
-- Only the 32-bit argument kinds are supported. String and fd arguments
-- raise a descriptive error because their encoding is not implemented.
-- 'FixedArgument' cannot be constructed (it holds 'Void').
putDynamicArgument :: DynamicArgument -> Put
putDynamicArgument (IntArgument x) = putInt32host x
putDynamicArgument (UIntArgument x) = putWord32host x
putDynamicArgument (ObjectArgument x) = putWord32host x
putDynamicArgument (NewIdArgument x) = putWord32host x
putDynamicArgument (FixedArgument impossible) = absurd impossible
putDynamicArgument (StringArgument _) = error "putDynamicArgument: string arguments are not supported"
putDynamicArgument (FdArgument _) = error "putDynamicArgument: fd arguments are not supported"



-- | Protocol state of the client end of a connection.
type ClientProtocolState m = ProtocolState 'Client m
-- | Protocol state of the server end of a connection.
type ServerProtocolState m = ProtocolState 'Server m

-- | State of one wayland connection on side @s@; object callbacks run in @m@.
data ProtocolState (s :: Side) m = ProtocolState {
  -- | Exception that made the protocol fail. While set, 'protocolStep'
  -- rethrows it and 'takeOutbox' returns no data.
  protocolException :: Maybe SomeException,
  -- | Total number of bytes passed to 'feedInput'.
  bytesReceived :: Word64,
  -- | Number of bytes sent. NOTE(review): never updated in this module.
  bytesSent :: Word64,
  -- | Incremental decoder that frames incoming bytes into raw messages.
  inboxDecoder :: Decoder RawMessage,
  -- | Data queued for sending, if any; emptied by 'takeOutbox'.
  outbox :: Maybe Put,
  -- | Objects of this connection, indexed by id.
  objects :: HashMap ObjectId (SomeObject s m)
}


-- | Callbacks of a client-side object.
type ClientCallback m i = Callback 'Client m i
-- | Callbacks of a server-side object.
type ServerCallback m i = Callback 'Server m i

-- | Handlers attached to an 'Object' of interface @i@ on side @s@.
data Callback s m i = Callback {
  -- | Handler for a received ('Down') message of the given object; it may
  -- read and modify the protocol state.
  messageCallback :: Object s m i -> Down s i -> StateT (ProtocolState s m) m ()
}

-- * Exceptions

-- | Wraps an exception raised by a message callback.
-- NOTE(review): not thrown anywhere in this module.
data CallbackFailed = CallbackFailed SomeException
  deriving stock Show
  deriving anyclass Exception

-- | Decoding failed. Fields: a description of what was being decoded, and
-- the decoder's error message.
data ParserFailed = ParserFailed String String
  deriving stock Show
  deriving anyclass Exception

-- | The peer violated the protocol, e.g. by addressing an object id that is
-- not registered.
data ProtocolException = ProtocolException String
  deriving stock Show
  deriving anyclass Exception

-- * Monad plumbing

-- | One state transition of the protocol, run in @m@. It returns the result
-- (or the exception the step raised), the bytes that should be sent (if any),
-- and the new state.
type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)

-- | Run a protocol action as a 'ProtocolStep'.
--
-- If the state already carries a protocol exception, that exception is
-- rethrown in @m@ and the action does not run. Otherwise an exception raised
-- by the action is returned as 'Left' and recorded in the resulting state.
-- Only the first recorded exception is kept. Queued output is taken from the
-- outbox with 'takeOutbox', so nothing is returned once the protocol failed.
protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a
protocolStep action inState = do
  -- A failed protocol cannot make progress: rethrow the stored exception.
  mapM_ throwM inState.protocolException
  (result, (output, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
  pure (result, output, outState)
  where
    storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m)
    storeExceptionIfFailed (Left ex, st) = (Left ex, recordFailure ex st)
    storeExceptionIfFailed x = x
    -- Named differently from the exported 'setException' to avoid shadowing it.
    recordFailure :: SomeException -> ProtocolState s m -> ProtocolState s m
    recordFailure ex st
      | isJust st.protocolException = st
      | otherwise = st{protocolException = Just (toException ex)}


-- * Exported functions

-- | Fresh protocol state for a new connection.
--
-- Registers the display object (id 1) and the registry object (id 2) with the
-- given callbacks, and queues the first message via 'sendInitialMessage'.
initialProtocolState
  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, IsInterfaceSide s wl_registry)
  => Callback s m wl_display
  -> Callback s m wl_registry
  -> ProtocolState s m
initialProtocolState wlDisplayCallback wlRegistryCallback =
  sendInitialMessage ProtocolState {
    protocolException = Nothing,
    bytesReceived = 0,
    bytesSent = 0,
    inboxDecoder = runGetIncremental getRawMessage,
    outbox = Nothing,
    objects = HM.fromList initialObjects
  }
  where
    initialObjects :: [(ObjectId, SomeObject s m)]
    initialObjects =
      [ (1, SomeObject (Object 1 wlDisplayCallback))
      , (2, SomeObject (Object 2 wlRegistryCallback))
      ]

-- | Feed the protocol newly received data
--
-- The bytes are pushed into the inbox decoder and counted in 'bytesReceived'.
-- Then every complete message that is now available is handled.
feedInput :: (IsSide s, MonadCatch m) => ByteString -> ProtocolStep s m ()
feedInput bytes = protocolStep $ State.modify appendInput >> runCallbacks
  where
    appendInput st = st {
      bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
      inboxDecoder = pushChunk st.inboxDecoder bytes
    }

-- | Send a message to the peer.
--
-- Not implemented yet. The step fails with a 'ProtocolException' (instead of
-- crashing on 'undefined'). 'protocolStep' returns it as 'Left' and records
-- it in the protocol state.
sendMessage :: (IsSide s, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m ()
sendMessage _object _message = protocolStep do
  -- TODO encode the message, queue it in the outbox and run callbacks
  throwM $ ProtocolException "sendMessage is not implemented"

-- | Mark the protocol as failed with the given exception.
--
-- Runs as a 'ProtocolStep'. If the protocol has already failed, the existing
-- exception is rethrown instead.
setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
setException ex = protocolStep (State.modify markFailed)
  where
    markFailed st = st{protocolException = Just (toException ex)}

-- * Internals

-- | Take data that has to be sent (if available)
--
-- The outbox is always emptied. Its contents are only returned while no
-- protocol exception is recorded; after a failure, queued data is dropped.
-- Pure: @m@ is only a type index here, so no monad constraint is needed.
takeOutbox :: ProtocolState s m -> (Maybe BSL.ByteString, ProtocolState s m)
takeOutbox st = (outboxBytes, st{outbox = Nothing})
  where
    outboxBytes
      | isJust st.protocolException = Nothing
      | otherwise = runPut <$> st.outbox


-- | Queue the first message of a connection: request opcode 1 on the display
-- object (id 1), announcing the registry id 2. Both ids match the objects
-- registered by 'initialProtocolState'.
sendInitialMessage :: ProtocolState s m -> ProtocolState s m
sendInitialMessage = sendMessageInternal displayId requestOpcode [NewIdArgument registryId]
  where
    displayId :: ObjectId
    displayId = 1
    -- NOTE(review): presumably wl_display.get_registry — confirm against the protocol XML.
    requestOpcode :: Opcode
    requestOpcode = 1
    registryId :: ObjectId
    registryId = 2

-- | Handle every complete message currently buffered in the inbox decoder,
-- in arrival order. Returns once no complete message is left.
runCallbacks :: (IsSide s, MonadCatch m) => StateT (ProtocolState s m) m ()
runCallbacks = do
  next <- receiveRawMessage
  maybe (pure ()) (\rawMessage -> handleMessage rawMessage >> runCallbacks) next

-- | Dispatch one raw message to the object it is addressed to.
--
-- The body is decoded with 'getMessageAction' and the resulting action is
-- then run. The following raise an exception:
--
-- * an unregistered object id ('ProtocolException');
-- * a message for an 'UnknownObject' ('ProtocolException');
-- * a decoding error, or a body that is not fully consumed ('ParserFailed').
handleMessage :: forall s m. (IsSide s, MonadCatch m) => RawMessage -> StateT (ProtocolState s m) m ()
handleMessage rawMessage@(oId, opcode, body) = do
  st <- State.get
  case HM.lookup oId st.objects of
    Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId

    Just (SomeObject object) ->
      case runGetOrFail (getMessageAction st.objects object rawMessage) body of
        Left (_, _, message) ->
          throwM $ ParserFailed (describeDownMessage object opcode body) message
        Right (leftovers, _, action)
          -- Run the decoded action; it is responsible for handling the message.
          | BSL.null leftovers -> action
          | otherwise ->
              throwM $ ParserFailed (describeDownMessage object opcode body) (show (BSL.length leftovers) <> "B not parsed")

    Just (UnknownObject interface unknownId) ->
      throwM $ ProtocolException $ "Received message for unknown object " <> interface <> "@" <> show unknownId

-- | Decoder for a message addressed to @object@.
--
-- Decodes the message for the given opcode and returns the action that
-- handles it. The action logs the message and then passes it to the object's
-- 'messageCallback'. The object map is currently unused.
getMessageAction
  :: (IsSide s, IsInterface i, MonadCatch m)
  => HashMap ObjectId (SomeObject s m)
  -> Object s m i
  -> RawMessage
  -> Get (ProtocolAction s m ())
getMessageAction _objects object@(Object _ callback) (_, opcode, body) = do
  message <- getDown object opcode
  pure do
    traceM $ "Received message " <> describeDownMessage object opcode body
    messageCallback callback object message

-- | An action that reads and modifies the protocol state.
type ProtocolAction s m a = StateT (ProtocolState s m) m a

-- | A framed message as read by 'getRawMessage': target object id, opcode
-- and the undecoded body (without the 8-byte header).
type RawMessage = (ObjectId, Opcode, BSL.ByteString)

-- | Take the next complete message out of the inbox decoder, or 'Nothing' if
-- more input is needed.
--
-- After a message is decoded, the decoder is restarted on the leftover input.
-- A framing error raises 'ParserFailed'.
receiveRawMessage :: forall s m. MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage)
receiveRawMessage = do
  st <- State.get
  (result, newDecoder) <- checkDecoder st.inboxDecoder
  State.put st{inboxDecoder = newDecoder}
  pure result
  where
    -- The outer 'MonadCatch m' (m is scoped) already provides what 'throwM' needs.
    checkDecoder
      :: Decoder RawMessage
      -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage)
    checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message)
    checkDecoder x@(Partial _) = pure (Nothing, x)
    checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)


-- | Read one framed message from the wire.
--
-- Layout: the object id, then one word that holds the total message size in
-- bytes (upper 16 bits, including the 8-byte header) and the opcode (lower
-- 16 bits), then the body.
--
-- Fails if the size is smaller than the header. A smaller size would give a
-- negative body length.
getRawMessage :: Get RawMessage
getRawMessage = do
  oId <- getWord32host
  sizeAndOpcode <- getWord32host
  let
    size = sizeAndOpcode `shiftR` 16
    opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
  if size < 8
    then fail $ "Invalid message size " <> show size <> " (smaller than the 8 byte header)"
    else do
      body <- getLazyByteString (fromIntegral size - 8)
      pure (oId, opcode, body)

-- | Read a string blob: a 32-bit length that counts the terminating NUL byte,
-- the contents followed by NUL, and zero padding to a 32-bit boundary. The NUL
-- is not part of the result.
--
-- Fails with a descriptive message for a zero length and for a missing NUL
-- terminator.
-- NOTE(review): in wayland a zero length encodes a null string, which the
-- 'BS.ByteString' result cannot represent.
getWaylandBlob :: Get BS.ByteString
getWaylandBlob = do
  size <- getWord32host
  blob <- getByteString (fromIntegral size)
  contents <- case BS.unsnoc blob of
    Nothing -> fail "Null string blobs are not supported"
    Just (contents, 0) -> pure contents
    Just _ -> fail "String blob is not NUL-terminated"
  skipPadding
  pure contents

-- | Write a string blob: a 32-bit length that counts the terminating NUL
-- byte, the contents followed by NUL, and zero padding to a 32-bit boundary.
--
-- This is the inverse of 'getWaylandBlob', which expects the length to
-- include the NUL.
putWaylandBlob :: BS.ByteString -> Put
putWaylandBlob blob = do
  -- Size on the wire includes the NUL terminator; padding is based on it too.
  let size = BS.length blob + 1
  putWord32host (fromIntegral size)
  putByteString blob
  putWord8 0
  replicateM_ ((4 - size `mod` 4) `mod` 4) (putWord8 0)


-- | Skip zero to three bytes so that the total number of bytes consumed so
-- far becomes a multiple of four.
skipPadding :: Get ()
skipPadding = bytesRead >>= skip . paddingFor
  where
    paddingFor offset = fromIntegral ((4 - offset `mod` 4) `mod` 4)


-- | Queue a message built from untyped arguments in the outbox.
--
-- The header is the object id, followed by one word holding the total size
-- in bytes (8-byte header plus arguments) in the upper 16 bits and the opcode
-- in the lower 16 bits. A message that does not fit the 16-bit size field
-- raises a descriptive error when the message is serialized.
sendMessageInternal :: ObjectId -> Opcode -> [DynamicArgument] -> ProtocolState s m -> ProtocolState s m
sendMessageInternal oId opcode args = sendRaw do
  putWord32host oId
  putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
  mapM_ putDynamicArgument args
  -- TODO padding
  where
    msgSize :: Word16
    msgSize
      | msgSizeInteger <= fromIntegral (maxBound :: Word16) = fromIntegral msgSizeInteger
      | otherwise = error $ "sendMessageInternal: message size " <> show msgSizeInteger <> " does not fit the 16-bit size field"
    -- Summed as 'Integer' so many arguments cannot overflow before the range check.
    msgSizeInteger :: Integer
    msgSizeInteger = 8 + sum (map (fromIntegral . argumentSize) args)

-- | Append serialized data to the outbox, after anything already queued.
sendRaw :: Put -> ProtocolState s m -> ProtocolState s m
sendRaw chunk st = st{outbox = Just (appendChunk st.outbox)}
  where
    appendChunk Nothing = chunk
    appendChunk (Just queued) = queued <> chunk