From f5aafcaf3881b6a5461a2bd8061d312cf6839f33 Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Mon, 6 Sep 2021 22:42:53 +0200
Subject: [PATCH] Add types for bidirectional connection

---
 src/Quasar/Wayland/Core.hs | 264 +++++++++++++++++++++++++++----------
 1 file changed, 196 insertions(+), 68 deletions(-)

diff --git a/src/Quasar/Wayland/Core.hs b/src/Quasar/Wayland/Core.hs
index 997db9e..0f015ad 100644
--- a/src/Quasar/Wayland/Core.hs
+++ b/src/Quasar/Wayland/Core.hs
@@ -1,20 +1,30 @@
 module Quasar.Wayland.Core (
   ObjectId,
   Opcode,
+  IsInterface(..),
+  Side(..),
+  Object,
+  IsSomeObject,
+  IsMessage(..),
   ProtocolState,
   ClientProtocolState,
-  initialClientProtocolState,
-  --ServerProtocolState,
-  --initialServerProtocolState,
+  ServerProtocolState,
+  ClientCallback,
+  ServerCallback,
+  Callback(..),
   Request,
   Event,
+  ProtocolStep,
   initialProtocolState,
+  sendMessage,
   feedInput,
-  takeOutbox,
+  setException,
 ) where
 
-import Control.Monad.State (State)
+import Control.Monad.Catch
+import Control.Monad.State (StateT, runStateT, state, modify)
 import Control.Monad.State qualified as State
+import Data.Binary
 import Data.Binary.Get
 import Data.Binary.Put
 import Data.Bits ((.&.), (.|.), shiftL, shiftR)
@@ -23,17 +33,62 @@ import Data.ByteString qualified as BS
 import Data.ByteString.Lazy qualified as BSL
 import Data.HashMap.Strict (HashMap)
 import Data.HashMap.Strict qualified as HM
+import Data.Maybe (isJust)
+import Data.Void (absurd)
 import Quasar.Prelude
 
 
 type ObjectId = Word32
-type ObjectType = String
 type Opcode = Word16
 
-data Object = Object {
-  objectId :: ObjectId,
-  objectType :: ObjectType
-}
+
+-- | A wayland interface
+class (Binary (TRequest a), Binary (TEvent a)) => IsInterface a where
+  type TRequest a
+  type TEvent a
+  interfaceName :: String
+
+class IsInterface a => IsObject (s :: Side) a where
+  type Up s a
+  type Down s a
+
+data Side = Client | Server
+
+data Object s m a = IsInterface a => Object ObjectId (Callback s m a)
+
+instance IsInterface a => IsObject 'Client a where
+  type Up 'Client a = TRequest a
+  type Down 'Client a = TEvent a
+
+instance IsInterface a => IsObject 'Server a where
+  type Up 'Server a = TEvent a
+  type Down 'Server a = TRequest a
+
+instance forall s m a. IsInterface a => IsSomeObject (Object s m a) where
+  objectId (Object oId _) = oId
+  objectInterfaceName _ = interfaceName @a
+
+mkObject :: forall s m a. IsInterface a => ObjectId -> Callback s m a -> Object s m a
+mkObject oId callback = Object @s @m @a oId callback
+
+
+class IsSomeObject a where
+  objectId :: a -> ObjectId
+  objectInterfaceName :: a -> String
+
+-- | Wayland object quantification wrapper
+data SomeObject = forall a. IsSomeObject a => SomeObject a
+
+instance IsSomeObject SomeObject where
+  objectId (SomeObject object) = objectId object
+  objectInterfaceName (SomeObject object) = objectInterfaceName object
+
+
+class IsMessage a where
+  messageName :: a -> String
+
+instance IsMessage Void where
+  messageName = absurd
 
 
 data Argument
@@ -61,16 +116,16 @@ putArgument (NewIdArgument x) = putWord32host x
 putArgument _ = undefined
 
 
-type ClientProtocolState = ProtocolState Request Event
-type ServerProtocolState = ProtocolState Event Request
+type ClientProtocolState m = ProtocolState 'Client m
+type ServerProtocolState m = ProtocolState 'Server m
 
-data ProtocolState up down = ProtocolState {
+data ProtocolState (s :: Side) m = ProtocolState {
+  protocolException :: Maybe SomeException,
   bytesReceived :: Word64,
   bytesSent :: Word64,
-  parser :: Decoder down,
-  inboxDecoder :: Decoder down,
+  inboxDecoder :: Decoder (ObjectId, Opcode, BSL.ByteString),
   outbox :: Maybe Put,
-  objects :: HashMap ObjectId Object
+  objects :: HashMap ObjectId SomeObject
 }
 
 data Request = Request ObjectId Opcode BSL.ByteString
@@ -78,71 +133,148 @@ data Request = Request ObjectId Opcode BSL.ByteString
 data Event = Event ObjectId Opcode (Either BSL.ByteString (Word32, BSL.ByteString, Word32))
   deriving stock Show
 
-initialClientProtocolState :: ClientProtocolState
-initialClientProtocolState = initialProtocolState decodeEvent
-
-initialProtocolState :: Get down -> ProtocolState up down
-initialProtocolState downGet = sendInitialMessage ProtocolState {
-  bytesReceived = 0,
-  bytesSent = 0,
-  parser = runGetIncremental downGet,
-  inboxDecoder = runGetIncremental downGet,
-  outbox = Nothing,
-  objects = HM.singleton 1 (Object 1 "wl_display")
+
+type ClientCallback m a = Callback 'Client m a
+type ServerCallback m a = Callback 'Server m a
+
+data Callback s m a = Callback {
+  messageCallback :: Object s m a -> Down s a -> StateT (ProtocolState s m) m ()
 }
 
-sendInitialMessage :: ProtocolState up down -> ProtocolState up down
-sendInitialMessage = sendMessage 1 1 [NewIdArgument 2]
+-- * Exceptions
+
+data CallbackFailed = CallbackFailed SomeException
+  deriving stock Show
+  deriving anyclass Exception
+
+data ParserFailed = ParserFailed String
+  deriving stock Show
+  deriving anyclass Exception
+
+-- * Monad plumbing
 
-feedInput :: forall up down. ByteString -> ProtocolState up down -> ([down], ProtocolState up down)
-feedInput bytes = State.runState do
-  State.modify (receive bytes)
-  go
+type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)
+
+protocolStep :: forall s m a. MonadCatch m => StateT (ProtocolState s m) m a -> ProtocolStep s m a
+protocolStep action inState = do
+  mapM_ throwM inState.protocolException
+  (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
+  pure (result, outbox, outState)
   where
-    go :: State (ProtocolState up down) [down]
-    go = State.state takeDownMsg >>= \case
-      Nothing -> pure []
-      Just msg -> (msg :) <$> go
+    storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m)
+    storeExceptionIfFailed (Left ex, st) = (Left ex, setException ex st)
+    storeExceptionIfFailed x = x
+    setException :: (MonadCatch m, Exception e) => e -> (ProtocolState s m) -> (ProtocolState s m)
+    setException ex st =
+      if isJust st.protocolException
+        then st
+        else st{protocolException = Just (toException ex)}
 
+-- * Exported functions
 
-receive :: forall up down. ByteString -> ProtocolState up down -> ProtocolState up down
-receive bytes state = state {
-  bytesReceived = state.bytesReceived + fromIntegral (BS.length bytes),
-  inboxDecoder = pushChunk state.inboxDecoder bytes
-}
+initialProtocolState
+  :: forall wl_display s m. IsInterface wl_display
+  => Callback s m wl_display
+  -> ProtocolState s m
+initialProtocolState wlDisplayCallback = sendInitialMessage initialState
+  where
+    wlDisplay :: Object s m wl_display
+    wlDisplay = mkObject 1 wlDisplayCallback
+    initialState :: ProtocolState s m
+    initialState = ProtocolState {
+      protocolException = Nothing,
+      bytesReceived = 0,
+      bytesSent = 0,
+      inboxDecoder = runGetIncremental getRawMessage,
+      outbox = Nothing,
+      objects = HM.singleton 1 (SomeObject wlDisplay)
+    }
+
+-- | Feed the protocol newly received data
+feedInput :: MonadCatch m => ByteString -> ProtocolStep s m ()
+feedInput bytes = protocolStep do
+  feed
+  runCallbacks
+  where
+    feed = modify \st -> st {
+      bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
+      inboxDecoder = pushChunk st.inboxDecoder bytes
+    }
+
+sendMessage :: MonadCatch m => Object s m a -> Up s a -> ProtocolStep s m ()
+sendMessage object message = protocolStep do
+  undefined message
+  runCallbacks
+
+setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
+setException ex = protocolStep do
+  modify \st -> st{protocolException = Just (toException ex)}
+
+-- * Internals
+
+-- | Take data that has to be sent (if available)
+takeOutbox :: MonadCatch m => ProtocolState s m ->  (Maybe BSL.ByteString, ProtocolState s m)
+takeOutbox st = (runPut <$> st.outbox, st{outbox = Nothing})
+
+
+sendInitialMessage :: ProtocolState s m -> ProtocolState s m
+sendInitialMessage = sendMessageInternal 1 1 [NewIdArgument 2]
 
-takeDownMsg :: forall up down. ProtocolState up down -> (Maybe down, ProtocolState up down)
-takeDownMsg state = (result, state{inboxDecoder = newDecoder})
+runCallbacks :: MonadCatch m => StateT (ProtocolState s m) m ()
+runCallbacks = receiveRawMessage >>= \case
+  Nothing -> pure ()
+  Just message -> do
+    traceM $ show message
+    runCallbacks
+
+
+type RawMessage = (ObjectId, Opcode, BSL.ByteString)
+
+getRawMessage :: Get RawMessage
+getRawMessage = do
+  oId <- getWord32host
+  sizeAndOpcode <- getWord32host
+  let
+    size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
+    opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
+  body <- getLazyByteString size
+  pure (oId, opcode, body)
+
+receiveRawMessage :: MonadCatch m => StateT (ProtocolState s m) m (Maybe RawMessage)
+receiveRawMessage = do
+  st <- State.get
+  (result, newDecoder) <- checkDecoder st.inboxDecoder
+  State.put st{inboxDecoder = newDecoder}
+
+  pure result
   where
-    result :: Maybe down
-    newDecoder :: Decoder down
-    (result, newDecoder) = checkDecoder state.inboxDecoder
-    checkDecoder :: Decoder down -> (Maybe down, Decoder down)
-    checkDecoder (Fail _ _ _) = undefined
-    checkDecoder x@(Partial _) = (Nothing, x)
-    checkDecoder (Done leftovers _ result) = (Just result, pushChunk state.parser leftovers)
+    checkDecoder :: MonadCatch m => Decoder RawMessage -> StateT (ProtocolState s m) m (Maybe RawMessage, Decoder RawMessage)
+    checkDecoder d@(Fail _ _ message) = throwM (ParserFailed message)
+    checkDecoder x@(Partial _) = pure (Nothing, x)
+    checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)
 
 
 decodeEvent :: Get Event
 decodeEvent = do
-  objectId <- getWord32host
+  oId <- getWord32host
   sizeAndOpcode <- getWord32host
   let
     size = fromIntegral (sizeAndOpcode `shiftR` 16) - 8
     opcode = fromIntegral (sizeAndOpcode .&. 0xFFFF)
-  body <- if (objectId == 2 && opcode == 0)
+  body <- if (oId == 2 && opcode == 0)
              then Right <$> parseGlobal
              else Left <$> getLazyByteString size <* skipPadding
-  pure $ Event objectId opcode body
+  pure $ Event oId opcode body
   where
     parseGlobal :: Get (Word32, BSL.ByteString, Word32)
     parseGlobal = (,,) <$> getWord32host <*> getWaylandString <*> getWord32host
-    getWaylandString :: Get BSL.ByteString
-    getWaylandString = do
-      size <- getWord32host
-      Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size)
-      skipPadding
-      pure string
+
+getWaylandString :: Get BSL.ByteString
+getWaylandString = do
+  size <- getWord32host
+  Just (string, 0) <- BSL.unsnoc <$> getLazyByteString (fromIntegral size)
+  skipPadding
+  pure string
 
 skipPadding :: Get ()
 skipPadding = do
@@ -150,9 +282,9 @@ skipPadding = do
   skip $ fromIntegral ((4 - (bytes `mod` 4)) `mod` 4)
 
 
-sendMessage :: ObjectId -> Opcode -> [Argument] -> ProtocolState up down -> ProtocolState up down
-sendMessage objectId opcode args = sendRaw do
-  putWord32host objectId
+sendMessageInternal :: ObjectId -> Opcode -> [Argument] -> ProtocolState s m -> ProtocolState s m
+sendMessageInternal oId opcode args = sendRaw do
+  putWord32host oId
   putWord32host $ (fromIntegral msgSize `shiftL` 16) .|. fromIntegral opcode
   mapM_ putArgument args
   -- TODO padding
@@ -162,11 +294,7 @@ sendMessage objectId opcode args = sendRaw do
     msgSizeInteger :: Integer
     msgSizeInteger = foldr ((+) . (fromIntegral . argumentSize)) 8 args :: Integer
 
-sendRaw :: Put -> ProtocolState up down -> ProtocolState up down
+sendRaw :: Put -> ProtocolState s m -> ProtocolState s m
 sendRaw put oldState = oldState {
   outbox = Just (maybe put (<> put) oldState.outbox)
 }
-
-takeOutbox :: ProtocolState up down -> (Maybe BSL.ByteString, ProtocolState up down)
-takeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
-akeOutbox state = (runPut <$> state.outbox, state{outbox = Nothing})
-- 
GitLab