Skip to content
Snippets Groups Projects
Core.hs 27.62 KiB
{-# LANGUAGE DeriveLift #-}

module Quasar.Wayland.Protocol.Core (
  ObjectId,
  GenericObjectId,
  NewId,
  GenericNewId,
  Opcode,
  Fixed(..),
  WlString(..),
  toString,
  fromString,
  IsSide(..),
  Side(..),
  IsInterface(..),
  interfaceName,
  Version,
  interfaceVersion,
  IsInterfaceSide(..),
  Object(objectProtocol),
  setEventHandler,
  setRequestHandler,
  setMessageHandler,
  getMessageHandler,
  NewObject,
  IsObject,
  IsMessage(..),
  ProtocolHandle,
  ProtocolM,

  -- * Protocol IO
  initializeProtocol,
  feedInput,
  setException,
  takeOutbox,
  runProtocolTransaction,
  runProtocolM,
  enterObject,

  -- * Low-level protocol interaction
  objectWireArgument,
  nullableObjectWireArgument,
  handleDestructor,
  checkObject,
  sendMessage,
  newObject,
  newObjectFromId,
  bindNewObject,
  getObject,
  getNullableObject,
  lookupObject,
  buildMessage,

  -- * wl_display interface
  handleWlDisplayError,
  handleWlDisplayDeleteId,

  -- * Protocol exceptions
  WireCallbackFailed(..),
  ParserFailed(..),
  ProtocolException(..),
  ProtocolUsageError(..),
  MaximumIdReached(..),
  ServerError(..),

  -- * Message decoder operations
  WireFormat(..),
  invalidOpcode,
) where
import Control.Concurrent.STM
import Control.Monad.Catch
import Control.Monad.Reader (ReaderT, runReaderT, ask, lift)
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.ByteString.UTF8 qualified as BSUTF8
import Data.Foldable (toList)
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.Proxy
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import Data.String (IsString(..))
import Data.Typeable (Typeable, cast)
import Data.Void (absurd)
import GHC.Conc (unsafeIOToSTM)
import GHC.TypeLits
import Quasar.Prelude
import System.Posix.Types (Fd(Fd))


-- | Wire id of an object whose interface is known statically. The phantom 'Symbol' is the interface name
-- (objects store their id as @ObjectId (InterfaceName i)@).
newtype ObjectId (j :: Symbol) = ObjectId Word32
  deriving newtype (Eq, Show, Hashable)

-- | Wire id of an object without static interface information (e.g. the key type of the object table).
newtype GenericObjectId = GenericObjectId Word32
  deriving newtype (Eq, Show, Hashable)

-- | Drop the interface name carried in the type of an 'ObjectId'; the numeric id is unchanged.
toGenericObjectId :: ObjectId j -> GenericObjectId
toGenericObjectId = \case
  ObjectId rawId -> GenericObjectId rawId

-- | Message opcode; sent in the low 16 bits of the second header word.
type Opcode = Word16

-- | Interface version number.
type Version = Word32


-- | A @new_id@ argument whose interface is known statically.
newtype NewId (j :: Symbol) = NewId (ObjectId j)
  deriving newtype (Eq, Show)

-- | A @new_id@ argument whose interface is only known at runtime (as used by wl_registry.bind): interface name,
-- version and the raw object id.
data GenericNewId = GenericNewId WlString Version Word32
  deriving stock (Eq, Show)


-- | Signed 24.8 fixed-point number, kept in its raw 32-bit wire representation.
newtype Fixed = Fixed Word32
  deriving newtype Eq

-- | Shows the raw wire representation, e.g. @[fixed 256]@.
instance Show Fixed where
  -- Unwrap the constructor first: calling 'show' on the 'Fixed' value itself would re-enter this instance and
  -- never terminate.
  show (Fixed repr) = "[fixed " <> show repr <> "]"


-- | A Wayland string argument.
--
-- The wire encoding is not officially specified, but UTF-8 is what is used in practice. The instances and helpers
-- here encode and decode UTF-8; the raw bytes remain available by pattern matching on the constructor.
newtype WlString = WlString BS.ByteString
  deriving newtype (Eq, Hashable)

instance Show WlString where
  show (WlString raw) = show (BSUTF8.toString raw)

instance IsString WlString where
  fromString str = WlString (BSUTF8.fromString str)

-- | Decode a 'WlString' as UTF-8.
toString :: WlString -> String
toString = \(WlString raw) -> BSUTF8.toString raw

-- | A fragment of an outgoing message: the body serializer, the body length in bytes, and the file descriptors to
-- send along with it.
data MessagePart = MessagePart Put Int (Seq Fd)

-- | Serializers and fd lists are concatenated in order; lengths are summed.
instance Semigroup MessagePart where
  MessagePart putA lenA fdsA <> MessagePart putB lenB fdsB =
    MessagePart (putA <> putB) (lenA + lenB) (fdsA <> fdsB)

instance Monoid MessagePart where
  mempty = MessagePart (pure ()) 0 Seq.empty


-- | Types that can appear as arguments of a Wayland message.
class (Eq a, Show a) => WireFormat a where
  -- | Serialize an argument; 'Left' if the value cannot be encoded.
  putArgument :: a -> Either SomeException MessagePart
  -- | Parse an argument from a message body. The parser yields a 'ProtocolM' action producing the value rather than
  -- the value itself.
  getArgument :: Get (ProtocolM s a)
  -- | Render an argument for debug output.
  showArgument :: a -> String

-- 32-bit signed integer, host byte order.
instance WireFormat Int32 where
  putArgument value = Right (MessagePart (putInt32host value) 4 mempty)
  getArgument = fmap pure getInt32host
  showArgument value = show value

-- 32-bit unsigned integer, host byte order.
instance WireFormat Word32 where
  putArgument value = Right (MessagePart (putWord32host value) 4 mempty)
  getArgument = fmap pure getWord32host
  showArgument value = show value

-- Fixed-point number, transported as its raw 32-bit representation.
instance WireFormat Fixed where
  putArgument (Fixed raw) = Right (MessagePart (putWord32host raw) 4 mempty)
  getArgument = fmap (pure . Fixed) getWord32host
  showArgument fixed = show fixed

-- Strings: length-prefixed (terminator included), NUL-terminated, padded to 4 bytes.
instance WireFormat WlString where
  putArgument (WlString raw) = putWaylandString raw
  getArgument = fmap (pure . WlString) getWaylandString
  showArgument str = show str

-- Arrays: length-prefixed raw bytes, padded to 4 bytes. Debug output only shows the size.
instance WireFormat BS.ByteString where
  putArgument = putWaylandArray
  getArgument = fmap pure getWaylandArray
  showArgument bytes = mconcat ["[array ", show (BS.length bytes), "B]"]

-- Object ids travel as a single 32-bit word; debug output is "interface@id".
instance KnownSymbol j => WireFormat (ObjectId (j :: Symbol)) where
  putArgument (ObjectId rawId) = Right (MessagePart (putWord32host rawId) 4 mempty)
  getArgument = fmap (pure . ObjectId) getWord32host
  showArgument (ObjectId rawId) = mconcat [symbolVal @j Proxy, "@", show rawId]

-- Object id of unknown interface.
instance WireFormat GenericObjectId where
  putArgument (GenericObjectId rawId) = Right (MessagePart (putWord32host rawId) 4 mempty)
  getArgument = fmap (pure . GenericObjectId) getWord32host
  showArgument genericId = mconcat ["[unknown]@", show genericId]

-- A typed new_id is encoded exactly like the object id it introduces.
instance KnownSymbol j => WireFormat (NewId (j :: Symbol)) where
  putArgument (NewId inner) = putArgument inner
  getArgument = fmap NewId <$> getArgument
  showArgument (NewId inner) = mconcat ["new ", symbolVal @j Proxy, "@", show inner]

-- An untyped new_id is encoded as three consecutive arguments: interface name, version and object id.
instance WireFormat GenericNewId where
  putArgument (GenericNewId interface version newId) =
    mconcat <$> sequence [putArgument interface, putArgument version, putArgument newId]
  getArgument = GenericNewId <<$>> getArgument <<*>> getArgument <<*>> getArgument
  showArgument (GenericNewId interface version newId) =
    "new " <> toString interface <> "[v" <> show version <> "]@" <> show newId

-- File descriptors are not part of the message body: they are transferred out-of-band and queued separately.
instance WireFormat Fd where
  putArgument fd = pure (MessagePart mempty 0 (Seq.singleton fd))
  -- Consumes no body bytes. Takes the next received fd from the inbox fd queue (filled by 'feedInput'), so fd
  -- arguments are matched to received descriptors in the order they are decoded.
  getArgument = pure do
    queuedFds <- readProtocolVar (.inboxFdsVar)
    case Seq.viewl queuedFds of
      Seq.EmptyL ->
        throwM $ ProtocolException "Received a message with an fd argument, but no file descriptor is queued"
      fd Seq.:< remainingFds -> do
        writeProtocolVar (.inboxFdsVar) remainingFds
        pure fd
  showArgument (Fd fd) = "fd@" <> show fd

-- | Class for a proxy type (in the Haskell sense) that describes a Wayland interface.
class (
    IsMessage (WireRequest i),
    IsMessage (WireEvent i),
    KnownSymbol (InterfaceName i),
    KnownNat (InterfaceVersion i),
    Typeable i
  )
  => IsInterface i where
  -- | Handler for received requests (the 'MessageHandler' of a server-side object).
  type RequestHandler i
  -- | Handler for received events (the 'MessageHandler' of a client-side object).
  type EventHandler i
  -- | Requests: messages sent by the client and received by the server.
  type WireRequest i
  -- | Events: messages sent by the server and received by the client.
  type WireEvent i
  -- | Interface name, as reported by 'interfaceName'.
  type InterfaceName i :: Symbol
  -- | Interface version, as reported by 'interfaceVersion'.
  type InterfaceVersion i :: Nat

-- | The wire name of interface @i@.
interfaceName :: forall i. IsInterface i => WlString
interfaceName =
  let name = symbolVal (Proxy :: Proxy (InterfaceName i))
  in fromString name

-- | The version of interface @i@.
interfaceVersion :: forall i. IsInterface i => Word32
interfaceVersion =
  let version = natVal (Proxy :: Proxy (InterfaceVersion i))
  in fromIntegral version

-- | Type-level behaviour of one end ('Client' or 'Server') of a connection.
class Typeable s => IsSide (s :: Side) where
  -- | Handler type for messages this side receives.
  type MessageHandler s i
  -- | Messages this side sends.
  type WireUp s i
  -- | Messages this side receives.
  type WireDown s i
  -- | First id allocated for objects created by this side.
  initialId :: Word32
  -- | Upper bound of this side's id range; 'newObject' throws 'MaximumIdReached' when allocation reaches it.
  maximumId :: Word32

instance IsSide 'Client where
  type MessageHandler 'Client i = EventHandler i
  type WireUp 'Client i = WireRequest i
  type WireDown 'Client i = WireEvent i
  -- Id #1 is reserved for wl_display
  initialId = 2
  maximumId = 0xfeffffff

instance IsSide 'Server where
  type MessageHandler 'Server i = RequestHandler i
  type WireUp 'Server i = WireEvent i
  type WireDown 'Server i = WireRequest i
  -- Server-created objects use the 0xff000000..0xffffffff id range of the Wayland protocol
  initialId = 0xff000000
  maximumId = 0xffffffff


-- | An interface as seen from one side of the connection.
class (
    IsSide s,
    IsInterface i,
    IsMessage (WireUp s i),
    IsMessage (WireDown s i)
  )
  => IsInterfaceSide (s :: Side) i where
  -- | Dispatch a received (decoded) message to the object's handler.
  handleMessage :: Object s i -> MessageHandler s i -> WireDown s i -> ProtocolM s ()


-- | Parser for a message received on @object@ with the given opcode.
getWireDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (ProtocolM s (WireDown s i))
getWireDown object opcode = getMessage @(WireDown s i) object opcode

-- | Serialize a message to send on an object. The object argument only fixes the message type.
putWireUp :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> Either SomeException (Opcode, MessagePart)
putWireUp _object message = putMessage @(WireUp s i) message


-- | Data kind identifying one end of a connection; used promoted (@'Client@ / @'Server@) as the @s@ parameter.
data Side = Client | Server
  deriving stock (Eq, Show)


-- | An object belonging to a Wayland connection.
data Object s i = IsInterfaceSide s i => Object {
  -- | The connection this object belongs to.
  objectProtocol :: (ProtocolHandle s),
  -- | Wire id of the object.
  objectId :: ObjectId (InterfaceName i),
  -- | Handler for received messages; 'Nothing' until set ('getMessageHandler' retries while it is unset).
  messageHandler :: TVar (Maybe (MessageHandler s i)),
  -- | Set to 'True' by 'handleDestructor'.
  destroyed :: TVar Bool
}


-- | Read the object's message handler. Retries the transaction while no handler is set.
getMessageHandler :: Object s i -> STM (MessageHandler s i)
getMessageHandler object = do
  mHandler <- readTVar object.messageHandler
  case mHandler of
    Nothing -> retry
    Just handler -> pure handler

-- | Install (or replace) the object's message handler.
setMessageHandler :: Object s i -> MessageHandler s i -> STM ()
setMessageHandler object handler = writeTVar object.messageHandler (Just handler)

-- | 'setMessageHandler' for server-side objects.
setRequestHandler :: Object 'Server i -> RequestHandler i -> STM ()
setRequestHandler object handler = setMessageHandler object handler

-- | 'setMessageHandler' for client-side objects.
setEventHandler :: Object 'Client i -> EventHandler i -> STM ()
setEventHandler object handler = setMessageHandler object handler

-- | Type alias to indicate an object is created with a message (i.e. introduced by a @new_id@ argument).
-- It is identical to 'Object'.
type NewObject s i = Object s i

-- | Renders an object as "interface@id" via 'showObject'.
instance IsInterface i => Show (Object s i) where
  show object = showObject object

-- | Objects whose id and interface name can be inspected at runtime.
class IsObject a where
  -- | Wire id of the object.
  genericObjectId :: a -> GenericObjectId
  -- | Name of the object's interface.
  objectInterfaceName :: a -> WlString
  -- | Debug rendering; "interface@id" by default.
  showObject :: a -> String
  showObject object = toString (objectInterfaceName object) <> "@" <> show (genericObjectId object)

-- | Objects that can describe undecoded messages (opcode plus raw body) for diagnostics.
class IsObjectSide a where
  -- | Describe an outgoing message.
  describeUpMessage :: a -> Opcode -> BSL.ByteString -> String
  -- | Describe an incoming message.
  describeDownMessage :: a -> Opcode -> BSL.ByteString -> String

-- The interface name comes from the type; the id from the object itself.
instance forall s i. IsInterface i => IsObject (Object s i) where
  genericObjectId = toGenericObjectId . (.objectId)
  objectInterfaceName = const (interfaceName @i)

-- Up and down messages differ only in which message type supplies the opcode name.
instance forall s i. IsInterfaceSide s i => IsObjectSide (Object s i) where
  describeUpMessage object opcode body =
    describeObjectMessage object (opcodeName @(WireUp s i) opcode) body
  describeDownMessage object opcode body =
    describeObjectMessage object (opcodeName @(WireDown s i) opcode) body

-- | Shared rendering for 'describeUpMessage' and 'describeDownMessage': "interface@id.name (nB)", using
-- "[invalidOpcode]" when the opcode has no name.
describeObjectMessage :: IsObject a => a -> Maybe String -> BSL.ByteString -> String
describeObjectMessage object mName body = mconcat [
  toString (objectInterfaceName object), "@", show (genericObjectId object),
  ".", fromMaybe "[invalidOpcode]" mName,
  " (", show (BSL.length body), "B)"]

-- | Existential wrapper hiding an object's interface, so objects of different interfaces can be stored together
-- (e.g. in the object table).
data SomeObject s = forall i. IsInterfaceSide s i => SomeObject (Object s i)

instance IsObject (SomeObject s) where
  genericObjectId = \(SomeObject object) -> genericObjectId object
  objectInterfaceName = \(SomeObject object) -> objectInterfaceName object

instance IsObjectSide (SomeObject s) where
  describeUpMessage (SomeObject object) opcode body = describeUpMessage object opcode body
  describeDownMessage (SomeObject object) opcode body = describeDownMessage object opcode body


-- | A message type (the requests or events of an interface).
class (Eq a, Show a) => IsMessage a where
  -- | Name of the message with this opcode, if there is one (used for debug output).
  opcodeName :: Opcode -> Maybe String
  -- | Parser for a message body with the given opcode, received on the given object.
  getMessage :: IsInterface i => Object s i -> Opcode -> Get (ProtocolM s a)
  -- | Serialize a message to its opcode and body.
  putMessage :: a -> Either SomeException (Opcode, MessagePart)

-- | The empty message type: no opcode has a name, parsing always fails, and there is nothing to serialize.
instance IsMessage Void where
  opcodeName = const Nothing
  getMessage object opcode = invalidOpcode object opcode
  putMessage impossible = absurd impossible
-- | Assemble a message from its opcode and serialized arguments (in order). Fails with the first argument error.
buildMessage :: Opcode -> [Either SomeException MessagePart] -> Either SomeException (Opcode, MessagePart)
buildMessage opcode parts = do
  serializedArgs <- sequence parts
  pure (opcode, mconcat serializedArgs)

-- | Parser failure for an opcode that the object's interface does not define.
invalidOpcode :: IsInterface i => Object s i -> Opcode -> Get a
invalidOpcode object opcode =
  fail ("Invalid opcode " <> show opcode <> " on " <> showObject object)

-- | Debug rendering of a decoded message: "interface@id.<message>".
showObjectMessage :: (IsObject a, IsMessage b) => a -> b -> String
showObjectMessage object message = mconcat [showObject object, ".", show message]


-- * Exceptions

-- | Wraps an exception raised by a callback. NOTE(review): not thrown anywhere in this module.
data WireCallbackFailed = WireCallbackFailed SomeException
  deriving stock Show
  deriving anyclass Exception

-- | Parsing received data failed: a description of what was being parsed, and the parser's error message.
data ParserFailed = ParserFailed String String
  deriving stock Show
  deriving anyclass Exception

-- | The peer sent something invalid (e.g. a message for an unknown object id).
data ProtocolException = ProtocolException String
  deriving stock Show
  deriving anyclass Exception

-- | The local side misused the protocol (e.g. sent to a deleted object, or sent an oversized message).
data ProtocolUsageError = ProtocolUsageError String
  deriving stock Show
  deriving anyclass Exception

-- | No more object ids can be allocated by this side.
data MaximumIdReached = MaximumIdReached
  deriving stock Show
  deriving anyclass Exception

-- | A wl_display.error was received: error code and message.
data ServerError = ServerError Word32 String
  deriving stock Show
  deriving anyclass Exception

-- | NOTE(review): neither thrown nor exported in this module.
data InvalidObject = InvalidObject String
  deriving stock Show
  deriving anyclass Exception

-- * Protocol state and monad plumbing

-- | Top-level protocol handle (used e.g. to send/receive data).
--
-- Holds either the live 'ProtocolState' or the exception that put the protocol into a failed state.
newtype ProtocolHandle (s :: Side) = ProtocolHandle {
  stateVar :: TVar (Either SomeException (ProtocolState s))
}

-- | State of a protocol that has not failed (kept as the 'Right' value in the 'stateVar' of a 'ProtocolHandle').
data ProtocolState (s :: Side) = ProtocolState {
  -- | Fresh 'Unique' created by 'initializeProtocol'.
  protocolKey :: Unique,
  -- | The handle that owns this state.
  protocolHandle :: ProtocolHandle s,
  -- | Total number of bytes passed to 'feedInput'.
  bytesReceivedVar :: TVar Int64,
  -- | Total number of bytes returned by 'takeOutbox'.
  bytesSentVar :: TVar Int64,
  -- | Incremental parser for framed incoming messages.
  inboxDecoderVar :: TVar (Decoder RawMessage),
  -- | Received file descriptors (appended by 'feedInput').
  inboxFdsVar :: TVar (Seq Fd),
  -- | Serialized outgoing data not yet taken by 'takeOutbox'; 'Nothing' when empty.
  outboxVar :: TVar (Maybe Put),
  -- | File descriptors to send along with the outbox data.
  outboxFdsVar :: TVar (Seq Fd),
  -- | Registered objects, keyed by id.
  objectsVar :: TVar (HashMap GenericObjectId (SomeObject s)),
  -- | Next id to allocate for objects created by this side.
  nextIdVar :: TVar Word32
}

-- | Protocol actions: STM transactions with read access to the current 'ProtocolState'.
type ProtocolM s a = ReaderT (ProtocolState s) STM a

-- | The handle of the protocol the action is running in.
askProtocol :: ProtocolM s (ProtocolHandle s)
askProtocol = do
  state <- ask
  pure state.protocolHandle

-- | Read a 'TVar' selected from the current protocol state.
readProtocolVar :: (ProtocolState s -> TVar a) -> ProtocolM s a
readProtocolVar selectVar = ask >>= lift . readTVar . selectVar

-- | Overwrite a 'TVar' selected from the current protocol state.
writeProtocolVar :: (ProtocolState s -> TVar a) -> a -> ProtocolM s ()
writeProtocolVar selectVar value = ask >>= \state -> lift (writeTVar (selectVar state) value)

-- | Lazily modify a 'TVar' selected from the current protocol state.
modifyProtocolVar :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s ()
modifyProtocolVar selectVar update = ask >>= \state -> lift (modifyTVar (selectVar state) update)

-- | Strictly modify (to WHNF) a 'TVar' selected from the current protocol state.
modifyProtocolVar' :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s ()
modifyProtocolVar' selectVar update = ask >>= \state -> lift (modifyTVar' (selectVar state) update)

-- | Update a 'TVar' selected from the current protocol state and return a result computed from the old value.
stateProtocolVar :: (ProtocolState s -> TVar a) -> (a -> (r, a)) -> ProtocolM s r
stateProtocolVar selectVar transition = ask >>= \state -> lift (stateTVar (selectVar state) transition)

-- | Replace the value of a 'TVar' selected from the current protocol state and return the previous value.
swapProtocolVar :: (ProtocolState s -> TVar a) -> a -> ProtocolM s a
swapProtocolVar selectVar value = ask >>= \state -> lift (swapTVar (selectVar state) value)

-- | Create the state of a new connection and register the wl_display object under id 1.
--
-- @wlDisplayMessageHandler@ builds the wl_display message handler from the new handle. @initializationAction@ is run
-- on the wl_display object in the same transaction; its result is returned together with the new handle.
initializeProtocol
  :: forall s wl_display a. (IsInterfaceSide s wl_display)
  => (ProtocolHandle s -> MessageHandler s wl_display)
  -> (Object s wl_display -> STM a)
  -> STM (a, ProtocolHandle s)
initializeProtocol wlDisplayMessageHandler initializationAction = do
  bytesReceivedVar <- newTVar 0
  bytesSentVar <- newTVar 0
  inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage
  inboxFdsVar <- newTVar mempty
  outboxVar <- newTVar Nothing
  outboxFdsVar <- newTVar mempty
  -- 'unsafeIOToSTM' may run again if the transaction is retried; that only produces another fresh 'Unique'.
  protocolKey <- unsafeIOToSTM newUnique
  objectsVar <- newTVar $ HM.empty
  nextIdVar <- newTVar (initialId @s)

  -- Create uninitialized to avoid use of a diverging 'mfix'
  -- (the state refers to the handle and the handle to the state; the placeholder is overwritten below).
  stateVar <- newTVar (Left unreachableCodePath)

  let protocol = ProtocolHandle {
    stateVar
  }

  let state = ProtocolState {
    protocolHandle = protocol,
    protocolKey,
    bytesReceivedVar,
    bytesSentVar,
    inboxDecoderVar,
    inboxFdsVar,
    outboxVar,
    outboxFdsVar,
    objectsVar,
    nextIdVar
  }
  writeTVar stateVar (Right state)

  -- Register wl_display (id 1) with its handler already installed.
  messageHandlerVar <- newTVar (Just (wlDisplayMessageHandler protocol))
  destroyed <- newTVar False
  let wlDisplay = Object protocol wlDisplayId messageHandlerVar destroyed
  modifyTVar' objectsVar (HM.insert (toGenericObjectId wlDisplayId) (SomeObject wlDisplay))

  result <- initializationAction wlDisplay
  pure (result, protocol)
  where
    wlDisplayId :: ObjectId (InterfaceName wl_display)
    wlDisplayId = ObjectId 1

-- | Run a protocol action in 'IO'. If an exception occurs, it is stored as a protocol failure and is then
-- re-thrown.
--
-- Throws an exception, if the protocol is already in a failed state.
--
-- The exception is caught inside the STM transaction, so the failed state is committed by the same transaction.
-- With STM exception semantics ('catchSTM'), the action's own 'TVar' writes are discarded when it throws. A 'retry'
-- inside the action is not caught and blocks the whole transaction.
runProtocolTransaction :: MonadIO m => ProtocolHandle s -> ProtocolM s a -> m a
runProtocolTransaction ProtocolHandle{stateVar} action = do
  result <- liftIO $ atomically do
    readTVar stateVar >>= \case
      -- Protocol is already in a failed state
      Left ex -> throwM ex
      Right state -> do
        -- Run action, catch exceptions
        runReaderT (try action) state >>= \case
          Left ex -> do
            -- Action failed, change protocol state to failed
            writeTVar stateVar (Left ex)
            pure (Left ex)
          Right result -> do
            pure (Right result)
  -- Transaction is committed, rethrow exception if the action failed
  either (liftIO . throwM) pure result


-- | Run a 'ProtocolM'-action inside 'STM'.
--
-- Throws the stored exception if the protocol is already in a failed state.
--
-- Exceptions raised by the action are not handled: they abort the surrounding STM transaction and are not recorded as
-- a protocol failure (use 'runProtocolTransaction' for that).
runProtocolM :: ProtocolHandle s -> ProtocolM s a -> STM a
runProtocolM protocol action = do
  current <- readTVar protocol.stateVar
  case current of
    Left failure -> throwM failure
    Right state -> runReaderT action state


-- | Feed the protocol newly received bytes and file descriptors, then handle every message that is now complete.
--
-- Any failure puts the protocol into a failed state and is rethrown (see 'runProtocolTransaction').
feedInput :: (IsSide s, MonadIO m) => ProtocolHandle s -> ByteString -> [Fd] -> m ()
feedInput protocol bytes fds =
  -- Exposing MonadIO instead of STM to the outside and using `runProtocolTransaction` here enforces correct exception
  -- handling.
  runProtocolTransaction protocol $ do
    modifyProtocolVar' (.bytesReceivedVar) (+ fromIntegral (BS.length bytes))
    modifyProtocolVar (.inboxDecoderVar) (\decoder -> pushChunk decoder bytes)
    modifyProtocolVar (.inboxFdsVar) (\queued -> queued <> Seq.fromList fds)
    receiveMessages

-- | Set the protocol to a failed state, e.g. when the socket closed unexpectedly.
--
-- The exception is then rethrown to the caller. If the protocol had already failed, the earlier exception is thrown
-- instead and the stored failure is left unchanged.
setException :: (Exception e, MonadIO m) => ProtocolHandle s -> e -> m ()
setException protocol ex = runProtocolTransaction protocol (throwM ex)

-- | Take all buffered outgoing data together with the queued file descriptors. Blocks (STM 'retry') until data is
-- available.
takeOutbox :: MonadIO m => ProtocolHandle s -> m (BSL.ByteString, [Fd])
takeOutbox protocol = runProtocolTransaction protocol $ do
  pending <- swapProtocolVar (.outboxVar) Nothing
  serializer <- maybe (lift retry) pure pending
  let payload = runPut serializer
  modifyProtocolVar' (.bytesSentVar) (+ BSL.length payload)
  queuedFds <- swapProtocolVar (.outboxFdsVar) mempty
  pure (payload, toList queuedFds)


-- | Create an object with a freshly allocated id. The caller is responsible for sending the 'NewId' immediately
-- (exactly once and before using the object).
--
-- Ids are allocated sequentially from 'initialId' up to and including 'maximumId'. Throws 'MaximumIdReached' once
-- that range is exhausted.
--
-- For use in generated code.
newObject
  :: forall s i. IsInterfaceSide s i
  => Maybe (MessageHandler s i)
  -> ProtocolM s (Object s i, NewId (InterfaceName i))
newObject messageHandler = do
  oId <- allocateObjectId
  let newId = NewId @(InterfaceName i) oId
  object <- newObjectFromId messageHandler newId
  pure (object, newId)
  where
    allocateObjectId :: ProtocolM s (ObjectId (InterfaceName i))
    allocateObjectId = do
      id' <- readProtocolVar (.nextIdVar)

      -- 'maximumId' is the last usable id. The lower-bound check catches the counter wrapping around to 0 after the
      -- server side has handed out 0xffffffff.
      when (id' > maximumId @s || id' < initialId @s) $ throwM MaximumIdReached

      writeProtocolVar (.nextIdVar) (id' + 1)
      pure $ ObjectId id'


-- | Create an object for a given id and register it in the object table. The caller is responsible for using a
-- 'NewId' exactly once while handling an incoming message.
--
-- NOTE(review): an already registered object with the same id is replaced without error.
--
-- For use in generated code.
newObjectFromId
  :: forall s i. IsInterfaceSide s i
  => Maybe (MessageHandler s i)
  -> NewId (InterfaceName i)
  -> ProtocolM s (Object s i)
newObjectFromId messageHandler (NewId oId) = do
  protocol <- askProtocol
  object <- lift do
    handlerVar <- newTVar messageHandler
    destroyedVar <- newTVar False
    pure (Object protocol oId handlerVar destroyedVar)
  modifyProtocolVar (.objectsVar) (HM.insert (genericObjectId object) (SomeObject object))
  pure object


-- | Create an object and return it with an untyped 'GenericNewId' (interface name, requested version, id). The caller
-- is responsible for sending the 'GenericNewId' immediately (exactly once and before using the object).
--
-- For implementing wl_registry.bind (which is low-level protocol functionality, but which depends on generated code).
bindNewObject
  :: forall i. IsInterfaceSide 'Client i
  => ProtocolHandle 'Client
  -> Version
  -> Maybe (MessageHandler 'Client i)
  -> STM (Object 'Client i, GenericNewId)
bindNewObject protocol version messageHandler = runProtocolM protocol do
  (object, NewId (ObjectId rawId)) <- newObject messageHandler
  let genericNewId = GenericNewId (interfaceName @i) version rawId
  pure (object, genericNewId)


-- | Recover a typed object from a 'SomeObject'. Fails with a description if its interface is not @i@.
fromSomeObject
  :: forall s i. IsInterfaceSide s i
  => SomeObject s -> Either String (Object s i)
fromSomeObject (SomeObject someObject) =
  case cast someObject of
    Nothing -> Left $ mconcat ["Expected object with type ",
      toString (interfaceName @i), ", but object has type ",
      toString (objectInterfaceName someObject)]
    Just object -> pure object


-- | Look up a registered object by id. Fails with a description if the id is unknown or the object has a different
-- interface.
lookupObject
  :: forall s i. IsInterfaceSide s i
  => ObjectId (InterfaceName i)
  -> ProtocolM s (Either String (Object s i))
lookupObject oId = do
  objects <- readProtocolVar (.objectsVar)
  pure $ maybe
    (Left $ mconcat ["No object with id ", show oId, " is registered"])
    fromSomeObject
    (HM.lookup (toGenericObjectId oId) objects)

-- | Lookup an object for an id or throw a `ProtocolException`. To be used from generated code when receiving an object
-- id.
getObject
  :: forall s i. IsInterfaceSide s i
  => ObjectId (InterfaceName i)
  -> ProtocolM s (Object s i)
getObject oId = lookupObject oId >>= \case
  Left err -> throwM (ProtocolException ("Received invalid object id: " <> err))
  Right object -> pure object

-- | Like 'getObject', but for nullable object arguments: id 0 (the null object) yields 'Nothing'. Any other id is
-- looked up, throwing a `ProtocolException` if it is invalid. To be used from generated code.
getNullableObject
  :: forall s i. IsInterfaceSide s i
  => ObjectId (InterfaceName i)
  -> ProtocolM s (Maybe (Object s i))
getNullableObject oId@(ObjectId rawId)
  | rawId == 0 = pure Nothing
  | otherwise = Just <$> getObject oId



-- | Handle a wl_display.error message by throwing a 'ServerError' with the received error code and message. Because
-- this is part of the core protocol but generated from the xml it has to be called from the client module.
--
-- The id of the object the error refers to is currently not included in the exception.
handleWlDisplayError :: ProtocolHandle 'Client -> GenericObjectId -> Word32 -> WlString -> STM ()
handleWlDisplayError _protocol _oId code message = throwM $ ServerError code (toString message)

-- | Handle a wl_display.delete_id message by removing the object with that id from the object table. Because this is
-- part of the core protocol but generated from the xml it has to be called from the client module.
handleWlDisplayDeleteId :: ProtocolHandle 'Client -> Word32 -> STM ()
handleWlDisplayDeleteId protocol oId =
  -- TODO call destructor
  runProtocolM protocol $ modifyProtocolVar (.objectsVar) (HM.delete (GenericObjectId oId))


-- | Mark an object as destroyed (with a debug trace). The object stays in the object table; on the client side it is
-- removed by 'handleWlDisplayDeleteId'.
handleDestructor :: IsInterfaceSide s i => Object s i -> ProtocolM s ()
handleDestructor object = do
  traceM ("Handling destructor for " <> showObject object)
  lift (writeTVar object.destroyed True)


-- | Check that an object is still registered in the object table; 'Left' with a description otherwise.
--
-- NOTE(review): the object's 'destroyed' flag is not consulted here.
checkObject :: IsInterface i => Object s i -> ProtocolM s (Either String ())
checkObject object = do
  -- TODO check if object belongs to current connection
  objects <- readProtocolVar (.objectsVar)
  pure $ if HM.member (genericObjectId object) objects
    then Right ()
    else Left (mconcat ["Object ", show object, " has been deleted"])


-- | Verify that an object can be used as an argument (throws 'ProtocolUsageError' otherwise) and return its id.
objectWireArgument :: IsInterface i => Object s i -> ProtocolM s (ObjectId (InterfaceName i))
objectWireArgument object =
  checkObject object >>= either
    (\msg -> throwM $ ProtocolUsageError $ "Tried to send a reference to an invalid object: " <> msg)
    (\() -> pure object.objectId)

-- | Like 'objectWireArgument' for nullable arguments: 'Nothing' is sent as the null id 0.
nullableObjectWireArgument :: IsInterface i => Maybe (Object s i) -> ProtocolM s (ObjectId (InterfaceName i))
nullableObjectWireArgument = maybe (pure (ObjectId 0)) objectWireArgument


-- | Sends a message, for use in generated code.
--
-- Throws 'ProtocolUsageError' if the object is not registered, or if the message (header included) does not fit the
-- 16-bit size field of the wire header. Serialization errors from the message encoder are rethrown.
sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> WireUp s i -> ProtocolM s ()
sendMessage object message = do
  checkObject object >>= \case
    Left msg -> throwM $ ProtocolUsageError $ "Tried to send message to an invalid object: " <> msg
    Right () -> pure ()

  (opcode, MessagePart putBody bodyLength fds) <- either throwM pure $ putWireUp object message

  -- The size field covers the header as well, so the limit applies to the total length. Checking only the body
  -- length would let the shifted size overflow into garbage.
  let messageLength = headerLength + bodyLength
  when (messageLength > fromIntegral (maxBound :: Word16)) $
    throwM $ ProtocolUsageError $ "Tried to send message larger than 2^16 bytes"

  traceM $ "-> " <> showObjectMessage object message
  sendRawMessage (putHeader opcode messageLength >> putBody) fds
  where
    -- Object id word plus size/opcode word
    headerLength :: Int
    headerLength = 8
    (GenericObjectId objectIdWord) = genericObjectId object
    putHeader :: Opcode -> Int -> Put
    putHeader opcode msgSize = do
      putWord32host objectIdWord
      -- Total size in the upper 16 bits, opcode in the lower 16 bits
      putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode

-- | Run a protocol action in the protocol the object belongs to (see 'runProtocolM').
enterObject :: forall s i a. Object s i -> ProtocolM s a -> STM a
enterObject object = runProtocolM object.objectProtocol


-- | Handle every complete message currently buffered in the inbox decoder, in order.
receiveMessages :: IsSide s => ProtocolM s ()
receiveMessages =
  maybe (pure ()) (\rawMessage -> handleRawMessage rawMessage >> receiveMessages) =<< receiveRawMessage

-- | Decode a raw message for the object it is addressed to and pass it to that object's handler.
--
-- Throws 'ProtocolException' for an unknown object id, and 'ParserFailed' if the body cannot be parsed or is not
-- consumed completely.
handleRawMessage :: forall s. RawMessage -> ProtocolM s ()
handleRawMessage (oId, opcode, body) = do
  objects <- readProtocolVar (.objectsVar)
  case HM.lookup oId objects of
    Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId
    Just (SomeObject object) -> runDecoded object
  where
    -- Parse the body; run the resulting action only if every byte was consumed.
    runDecoded :: forall i. IsInterfaceSide s i => Object s i -> ProtocolM s ()
    runDecoded object =
      case runGetOrFail (decodeMessage object) body of
        Left (_, _, parseError) ->
          throwM $ ParserFailed (describeDownMessage object opcode body) parseError
        Right (leftovers, _, action)
          | BSL.null leftovers -> action
          | otherwise ->
              throwM $ ParserFailed (describeDownMessage object opcode body) (show (BSL.length leftovers) <> "B not parsed")

    -- Parser producing the action that materializes the message, traces it and dispatches it to the handler.
    decodeMessage
      :: forall i. IsInterfaceSide s i
      => Object s i
      -> Get (ProtocolM s ())
    decodeMessage object = do
      decodeAction <- getWireDown object opcode
      pure $ do
        message <- decodeAction
        traceM $ "<- " <> showObjectMessage object message
        handler <- lift $ getMessageHandler object
        handleMessage @s @i object handler message

-- | A framed incoming message: target object id, opcode, and body (the 8-byte header stripped).
type RawMessage = (GenericObjectId, Opcode, BSL.ByteString)

-- | Take the next complete raw message from the inbox decoder, or 'Nothing' if more input is needed. Throws
-- 'ParserFailed' if framing fails.
receiveRawMessage :: forall s. ProtocolM s (Maybe RawMessage)
receiveRawMessage = do
  decoder <- readProtocolVar (.inboxDecoderVar)
  (result, nextDecoder) <- stepDecoder decoder
  writeProtocolVar (.inboxDecoderVar) nextDecoder
  pure result
  where
    stepDecoder
      :: Decoder RawMessage
      -> ProtocolM s (Maybe RawMessage, Decoder RawMessage)
    stepDecoder = \case
      Fail _ _ parseError -> throwM (ParserFailed "RawMessage" parseError)
      partial@(Partial _) -> pure (Nothing, partial)
      Done leftovers _ rawMessage ->
        -- Start a fresh decoder primed with the bytes that follow this message.
        pure (Just rawMessage, runGetIncremental getRawMessage `pushChunk` leftovers)


-- | Parse one framed message: the object id, then a word with the total message size (header included) in the upper
-- 16 bits and the opcode in the lower 16 bits, then the body.
--
-- Fails if the declared size is smaller than the 8-byte header.
getRawMessage :: Get RawMessage
getRawMessage = do
  oId <- GenericObjectId <$> getWord32host
  sizeAndOpcode <- getWord32host
  let
    totalSize = fromIntegral (sizeAndOpcode `shiftR` 16)
    opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
  -- A size below the header length would give a negative body length, which 'getLazyByteString' does not reject.
  when (totalSize < headerSize) $
    fail $ "Invalid message size " <> show totalSize <> " (smaller than the message header)"
  body <- getLazyByteString (totalSize - headerSize)
  pure (oId, opcode, body)
  where
    headerSize :: Int64
    headerSize = 8

-- | Parse a string argument: an array whose last byte is the NUL terminator, which is stripped.
--
-- Fails for a zero-length array (a null string, which 'WlString' cannot represent) and for a missing terminator.
getWaylandString :: Get BS.ByteString
getWaylandString = do
  array <- getWaylandArray
  case BS.unsnoc array of
    Nothing -> fail "Invalid string: received null string (length 0)"
    Just (string, 0) -> pure string
    Just (_, _) -> fail "Invalid string: missing NUL terminator"

-- | Parse an array argument: a 32-bit length, that many bytes, then padding up to a 4-byte boundary.
getWaylandArray :: Get BS.ByteString
getWaylandArray = do
  len <- fromIntegral <$> getWord32host
  bytes <- getByteString len
  skipPadding
  pure bytes

-- | Serialize a string argument: 32-bit length (terminator included), the bytes, a NUL terminator, and padding to a
-- 4-byte boundary. Throws 'ProtocolUsageError' if the length exceeds 16 bits.
putWaylandString :: MonadThrow m => BS.ByteString -> m MessagePart
putWaylandString blob
  | totalLen > fromIntegral (maxBound :: Word16) =
      throwM $ ProtocolUsageError $ "Tried to send string larger than 2^16 bytes"
  | otherwise =
      pure $ MessagePart serialized (4 + totalLen + padLen) mempty
  where
    -- Data length including the NUL terminator
    totalLen = BS.length blob + 1
    padLen = padding totalLen
    serialized = do
      putWord32host (fromIntegral totalLen)
      putByteString blob
      -- NUL terminator followed by the padding bytes
      replicateM_ (1 + padLen) (putWord8 0)

-- | Serialize an array argument: 32-bit length, the bytes, and padding to a 4-byte boundary. Throws
-- 'ProtocolUsageError' if the length exceeds 16 bits.
putWaylandArray :: MonadThrow m => BS.ByteString -> m MessagePart
putWaylandArray blob
  | dataLen > fromIntegral (maxBound :: Word16) =
      throwM $ ProtocolUsageError $ "Tried to send array larger than 2^16 bytes"
  | otherwise =
      pure $ MessagePart serialized (4 + dataLen + padLen) mempty
  where
    -- Data length without padding
    dataLen = BS.length blob
    padLen = padding dataLen
    serialized = do
      putWord32host (fromIntegral dataLen)
      putByteString blob
      replicateM_ padLen (putWord8 0)


-- | Skip the bytes that pad the current read position (counted from the start of the parse) to a multiple of 4.
skipPadding :: Get ()
skipPadding = bytesRead >>= skip . fromIntegral . padding

-- | Number of bytes needed to round @size@ up to a multiple of 4.
padding :: Integral a => a -> a
padding size = case size `mod` 4 of
  0 -> 0
  remainder -> 4 - remainder


-- | Append serialized data to the outbox and queue file descriptors to be sent with it.
sendRawMessage :: Put -> Seq Fd -> ProtocolM s ()
sendRawMessage serialized fds = do
  modifyProtocolVar (.outboxVar) $ \case
    Nothing -> Just serialized
    Just pending -> Just (pending <> serialized)
  modifyProtocolVar (.outboxFdsVar) (\queued -> queued <> fds)