From b43d66a94c48d9d1091467ca6dc416c7f1698cca Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Thu, 16 Sep 2021 14:39:36 +0200
Subject: [PATCH] Pin the protocol monad to STM

I really like the idea of a pure wire protocol implementation and might
revisit that later, but to get an MVP it's much easier to use STM in the
core, since all interface impementations are going to use it anyway.
---
 src/Quasar/Wayland/Client.hs        |   4 +-
 src/Quasar/Wayland/Connection.hs    |   8 +-
 src/Quasar/Wayland/Protocol.hs      |  11 +--
 src/Quasar/Wayland/Protocol/Core.hs | 111 ++++++++++++++--------------
 4 files changed, 68 insertions(+), 66 deletions(-)

diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs
index 6f8f2cf..2465482 100644
--- a/src/Quasar/Wayland/Client.hs
+++ b/src/Quasar/Wayland/Client.hs
@@ -18,7 +18,7 @@ import System.FilePath ((</>), isRelative)
 import Text.Read (readEither)
 
 
-data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client STM I_wl_display)
+data WaylandClient = WaylandClient (WaylandConnection 'Client) (Object 'Client I_wl_display)
 
 instance IsResourceManager WaylandClient where
   toResourceManager (WaylandClient connection _) = toResourceManager connection
@@ -30,7 +30,7 @@ newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient
 newWaylandClient socket = do
   (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket
 
-  (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @STM @I_wl_registry (traceCallback ignoreMessage)
+  (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage)
   stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId
   pure $ WaylandClient connection wlDisplay
 
diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs
index f5b0b99..4a401e3 100644
--- a/src/Quasar/Wayland/Connection.hs
+++ b/src/Quasar/Wayland/Connection.hs
@@ -19,7 +19,7 @@ import Quasar.Wayland.Protocol.Generated
 
 
 data WaylandConnection s = WaylandConnection {
-  protocolStateVar :: TVar (ProtocolState s STM),
+  protocolStateVar :: TVar (ProtocolState s),
   outboxVar :: TMVar BSL.ByteString,
   socket :: Socket,
   resourceManager :: ResourceManager
@@ -37,9 +37,9 @@ data SocketClosed = SocketClosed
 
 newWaylandConnection
   :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m)
-  => Callback s STM wl_display
+  => Callback s wl_display
   -> Socket
-  -> m (WaylandConnection s, Object s STM wl_display)
+  -> m (WaylandConnection s, Object s wl_display)
 newWaylandConnection wlDisplayCallback socket = do
   protocolStateVar <- liftIO $ newTVarIO protocolState
   outboxVar <- liftIO newEmptyTMVarIO
@@ -64,7 +64,7 @@ newWaylandConnection wlDisplayCallback socket = do
   where
     (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback
 
-stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s STM a -> m a
+stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s a -> m a
 stepProtocol connection step = liftIO do
   result <- atomically do
     oldState <- readTVar connection.protocolStateVar
diff --git a/src/Quasar/Wayland/Protocol.hs b/src/Quasar/Wayland/Protocol.hs
index 03dd836..7e8d34b 100644
--- a/src/Quasar/Wayland/Protocol.hs
+++ b/src/Quasar/Wayland/Protocol.hs
@@ -3,6 +3,7 @@ module Quasar.Wayland.Protocol (
   createClientStateWithRegistry
 ) where
 
+import Control.Concurrent.STM
 import Control.Monad.Catch
 import Control.Monad.State (StateT, runStateT)
 import Data.ByteString.UTF8 (toString)
@@ -11,24 +12,24 @@ import Quasar.Wayland.Protocol.Core
 import Quasar.Wayland.Protocol.Generated
 
 
-createClientStateWithRegistry :: forall m. MonadCatch m => m (ProtocolState 'Client m)
+createClientStateWithRegistry :: STM (ProtocolState 'Client)
 createClientStateWithRegistry = do
   (wlRegistry, state') <- runStateT go initialState'
   pure state'
   where
     (initialState', wlDisplay) = initialProtocolState wlDisplayCallback
 
-    go :: ProtocolAction 'Client m (Object 'Client m I_wl_registry)
+    go :: ProtocolAction 'Client (Object 'Client I_wl_registry)
     go = do
-      (wlRegistry, newId) <- newObjectInternal @'Client @m @I_wl_registry (traceCallback ignoreMessage)
+      (wlRegistry, newId) <- newObjectInternal @'Client @I_wl_registry (traceCallback ignoreMessage)
       sendMessageInternal wlDisplay $ R_wl_display_get_registry newId
 
       pure wlRegistry
 
-    wlDisplayCallback :: forall m. (IsInterfaceSide 'Client I_wl_display, MonadCatch m) => Callback 'Client m I_wl_display
+    wlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Client I_wl_display
     wlDisplayCallback = internalFnCallback handler
       where
-        handler :: Object 'Client m I_wl_display -> E_wl_display -> ProtocolAction 'Client m ()
+        handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolAction 'Client ()
         -- TODO parse oId
         handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (toString message)
         handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete
diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index a8e355a..73067ae 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -41,6 +41,7 @@ module Quasar.Wayland.Protocol.Core (
   invalidOpcode,
 ) where
 
+import Control.Concurrent.STM
 import Control.Monad (replicateM_)
 import Control.Monad.Catch
 import Control.Monad.State (StateT, runStateT)
@@ -213,21 +214,21 @@ class (
   => IsInterfaceSide (s :: Side) i
 
 
-getDown :: forall s m i. IsInterfaceSide s i => Object s m i -> Opcode -> Get (Down s i)
+getDown :: forall s i. IsInterfaceSide s i => Object s i -> Opcode -> Get (Down s i)
 getDown = getMessage @(Down s i)
 
-putUp :: forall s m i. IsInterfaceSide s i => Object s m i -> Up s i -> PutM Opcode
+putUp :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> PutM Opcode
 putUp _ = putMessage @(Up s i)
 
 
-class IsInterfaceSide s i => IsInterfaceHandler s m i a where
-  handleMessage :: a -> Object s m i -> Down s i -> ProtocolAction s m ()
+class IsInterfaceSide s i => IsInterfaceHandler s i a where
+  handleMessage :: a -> Object s i -> Down s i -> ProtocolAction s ()
 
 
 -- | Data kind
 data Side = Client | Server
 
-data Object s m i = IsInterfaceSide s i => Object GenericObjectId (Callback s m i)
+data Object s i = IsInterfaceSide s i => Object GenericObjectId (Callback s i)
 
 class IsObject a where
   objectId :: a -> GenericObjectId
@@ -237,11 +238,11 @@ class IsObjectSide a where
   describeUpMessage :: a -> Opcode -> BSL.ByteString -> String
   describeDownMessage :: a -> Opcode -> BSL.ByteString -> String
 
-instance forall s m i. IsInterface i => IsObject (Object s m i) where
+instance forall s i. IsInterface i => IsObject (Object s i) where
   objectId (Object oId _) = oId
   objectInterfaceName _ = interfaceName @i
 
-instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where
+instance forall s i. IsInterfaceSide s i => IsObjectSide (Object s i) where
   describeUpMessage object opcode body =
     objectInterfaceName object <> "@" <> show (objectId object) <>
     "." <> fromMaybe "[invalidOpcode]" (opcodeName @(Up s i) opcode) <>
@@ -252,17 +253,17 @@ instance forall s m i. IsInterfaceSide s i => IsObjectSide (Object s m i) where
     " (" <> show (BSL.length body) <> "B)"
 
 -- | Wayland object quantification wrapper
-data SomeObject s m
-  = forall i. IsInterfaceSide s i => SomeObject (Object s m i)
+data SomeObject s
+  = forall i. IsInterfaceSide s i => SomeObject (Object s i)
   | UnknownObject String GenericObjectId
 
-instance IsObject (SomeObject s m) where
+instance IsObject (SomeObject s) where
   objectId (SomeObject object) = objectId object
   objectId (UnknownObject _ oId) = oId
   objectInterfaceName (SomeObject object) = objectInterfaceName object
   objectInterfaceName (UnknownObject interface _) = interface
 
-instance IsObjectSide (SomeObject s m) where
+instance IsObjectSide (SomeObject s) where
   describeUpMessage (SomeObject object) = describeUpMessage object
   describeUpMessage (UnknownObject interface oId) =
     \opcode body -> interface <> "@" <> show oId <> ".#" <> show opcode <>
@@ -275,7 +276,7 @@ instance IsObjectSide (SomeObject s m) where
 
 class (Eq a, Show a) => IsMessage a where
   opcodeName :: Opcode -> Maybe String
-  getMessage :: IsInterface i => Object s m i -> Opcode -> Get a
+  getMessage :: IsInterface i => Object s i -> Opcode -> Get a
   putMessage :: a -> PutM Opcode
 
 instance IsMessage Void where
@@ -283,7 +284,7 @@ instance IsMessage Void where
   getMessage = invalidOpcode
   putMessage = absurd
 
-invalidOpcode :: IsInterface i => Object s m i -> Opcode -> Get a
+invalidOpcode :: IsInterface i => Object s i -> Opcode -> Get a
 invalidOpcode object opcode =
   fail $ "Invalid opcode " <> show opcode <> " on " <> objectInterfaceName object <> "@" <> show (objectId object)
 
@@ -292,29 +293,29 @@ showObjectMessage object message =
   objectInterfaceName object <> "@" <> show (objectId object) <> "." <> show message
 
 
-data ProtocolState (s :: Side) m = ProtocolState {
+data ProtocolState (s :: Side) = ProtocolState {
   protocolException :: Maybe SomeException,
   bytesReceived :: !Int64,
   bytesSent :: !Int64,
   inboxDecoder :: Decoder RawMessage,
   outbox :: Maybe Put,
-  objects :: HashMap GenericObjectId (SomeObject s m),
+  objects :: HashMap GenericObjectId (SomeObject s),
   nextId :: Word32
 }
 
 
-data Callback s m i = forall a. IsInterfaceHandler s m i a => Callback a
+data Callback s i = forall a. IsInterfaceHandler s i a => Callback a
 
-instance IsInterfaceSide s i => IsInterfaceHandler s m i (Callback s m i) where
+instance IsInterfaceSide s i => IsInterfaceHandler s i (Callback s i) where
   handleMessage (Callback callback) = handleMessage callback
 
 
-data LowLevelCallback s m i = IsInterfaceSide s i => FnCallback (Object s m i -> Down s i -> ProtocolAction s m ())
+data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolAction s ())
 
-instance IsInterfaceSide s i => IsInterfaceHandler s m i (LowLevelCallback s m i) where
+instance IsInterfaceSide s i => IsInterfaceHandler s i (LowLevelCallback s i) where
   handleMessage (FnCallback fn) object msg = fn object msg
 
-internalFnCallback :: IsInterfaceSide s i => (Object s m i -> Down s i -> ProtocolAction s m ()) -> Callback s m i
+internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolAction s ()) -> Callback s i
 internalFnCallback = Callback . FnCallback
 
 
@@ -327,13 +328,13 @@ internalFnCallback = Callback . FnCallback
 -- trace message.
 --
 -- Uses `traceM` internally.
-traceCallback :: (IsInterfaceSide 'Client i, Monad m) => Callback 'Client m i -> Callback 'Client m i
+traceCallback :: IsInterfaceSide 'Client i => Callback 'Client i -> Callback 'Client i
 traceCallback next = internalFnCallback \object message -> do
   traceM $ "<- " <> showObjectMessage object message
   handleMessage next object message
 
 -- | A `Callback` that ignores all messages. Intended for development purposes, e.g. together with `traceCallback`.
-ignoreMessage :: (IsInterfaceSide 'Client i, Monad m) => Callback 'Client m i
+ignoreMessage :: IsInterfaceSide 'Client i => Callback 'Client i
 ignoreMessage = internalFnCallback \_ _ -> pure ()
 
 -- * Exceptions
@@ -360,21 +361,21 @@ data ServerError = ServerError Word32 String
 
 -- * Monad plumbing
 
-type ProtocolStep s m a = ProtocolState s m -> m (Either SomeException a, Maybe BSL.ByteString, ProtocolState s m)
+type ProtocolStep s a = ProtocolState s -> STM (Either SomeException a, Maybe BSL.ByteString, ProtocolState s)
 
 -- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions.
-type ProtocolAction s m a = StateT (ProtocolState s m) m a
+type ProtocolAction s a = StateT (ProtocolState s) STM a
 
-protocolStep :: forall s m a. MonadCatch m => ProtocolAction s m a -> ProtocolStep s m a
+protocolStep :: forall s a. ProtocolAction s a -> ProtocolStep s a
 protocolStep action inState = do
   mapM_ throwM inState.protocolException
   (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
   pure (result, outbox, outState)
   where
-    storeExceptionIfFailed :: (Either SomeException a, ProtocolState s m) -> (Either SomeException a, ProtocolState s m)
+    storeExceptionIfFailed :: (Either SomeException a, ProtocolState s) -> (Either SomeException a, ProtocolState s)
     storeExceptionIfFailed (Left ex, st) = (Left ex, setException' ex st)
     storeExceptionIfFailed x = x
-    setException' :: Exception e => e -> (ProtocolState s m) -> (ProtocolState s m)
+    setException' :: Exception e => e -> ProtocolState s -> ProtocolState s
     setException' ex st =
       if isJust st.protocolException
         then st
@@ -384,14 +385,14 @@ protocolStep action inState = do
 -- * Exported functions
 
 initialProtocolState
-  :: forall wl_display wl_registry s m. IsInterfaceSide s wl_display
-  => Callback s m wl_display
-  -> (ProtocolState s m, Object s m wl_display)
+  :: forall wl_display wl_registry s. (IsInterfaceSide s wl_display)
+  => Callback s wl_display
+  -> (ProtocolState s, Object s wl_display)
 initialProtocolState wlDisplayCallback = (initialState, wlDisplay)
   where
-    wlDisplay :: Object s m wl_display
+    wlDisplay :: Object s wl_display
     wlDisplay = Object 1 wlDisplayCallback
-    initialState :: ProtocolState s m
+    initialState :: ProtocolState s
     initialState = ProtocolState {
       protocolException = Nothing,
       bytesReceived = 0,
@@ -403,7 +404,7 @@ initialProtocolState wlDisplayCallback = (initialState, wlDisplay)
     }
 
 -- | Feed the protocol newly received data
-feedInput :: (IsSide s, MonadCatch m) => ByteString -> ProtocolStep s m ()
+feedInput :: IsSide s => ByteString -> ProtocolStep s ()
 feedInput bytes = protocolStep do
   feed
   receiveMessages
@@ -413,29 +414,29 @@ feedInput bytes = protocolStep do
       inboxDecoder = pushChunk st.inboxDecoder bytes
     }
 
-setException :: (MonadCatch m, Exception e) => e -> ProtocolStep s m ()
+setException :: Exception e => e -> ProtocolStep s ()
 setException ex = protocolStep do
   State.modify \st -> st{protocolException = Just (toException ex)}
 
 
 -- | Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object.
 newObject
-  :: forall s m i. (IsInterfaceSide s i, MonadCatch m)
-  => Callback s m i
-  -> ProtocolStep s m (Object s m i, NewId (InterfaceName i))
+  :: forall s i. IsInterfaceSide s i
+  => Callback s i
+  -> ProtocolStep s (Object s i, NewId (InterfaceName i))
 newObject callback = protocolStep $ newObjectInternal callback
 
 newObjectInternal
-  :: forall s m i. (IsInterfaceSide s i, MonadCatch m)
-  => Callback s m i
-  -> ProtocolAction s m (Object s m i, NewId (InterfaceName i))
+  :: forall s i. IsInterfaceSide s i
+  => Callback s i
+  -> ProtocolAction s (Object s i, NewId (InterfaceName i))
 newObjectInternal callback = do
-  genOId <- allocateObjectId @s @m
+  genOId <- allocateObjectId @s
   let oId = NewId @(InterfaceName i) genOId
   object <- newObjectFromId oId callback
   pure (object, oId)
   where
-    allocateObjectId :: forall s m. (IsSide s, MonadCatch m) => ProtocolAction s m GenericObjectId
+    allocateObjectId :: forall s. IsSide s => ProtocolAction s GenericObjectId
     allocateObjectId = do
       st <- State.get
       let
@@ -447,10 +448,10 @@ newObjectInternal callback = do
       pure id
 
 newObjectFromId
-  :: forall s m i. (IsInterfaceSide s i, MonadCatch m)
+  :: forall s i. IsInterfaceSide s i
   => NewId (InterfaceName i)
-  -> Callback s m i
-  -> ProtocolAction s m (Object s m i)
+  -> Callback s i
+  -> ProtocolAction s (Object s i)
 newObjectFromId (NewId oId) callback = do
   let
     object = Object oId callback
@@ -460,10 +461,10 @@ newObjectFromId (NewId oId) callback = do
 
 
 -- | Sends a message without checking any ids or creating proxy objects objects.
-sendMessage :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolStep s m ()
+sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolStep s ()
 sendMessage object message = protocolStep $ sendMessageInternal object message
 
-sendMessageInternal :: forall s m i. (IsInterfaceSide s i, MonadCatch m) => Object s m i -> Up s i -> ProtocolAction s m ()
+sendMessageInternal :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolAction s ()
 sendMessageInternal object message = do
   traceM $ "-> " <> showObjectMessage object message
   sendRawMessage messageWithHeader
@@ -483,7 +484,7 @@ sendMessageInternal object message = do
     msgSizeInteger = 8 + fromIntegral (BSL.length body)
 
 -- | Take data that has to be sent (if available)
-takeOutbox :: ProtocolState s m ->  (Maybe BSL.ByteString, ProtocolState s m)
+takeOutbox :: ProtocolState s ->  (Maybe BSL.ByteString, ProtocolState s)
 takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent + outboxNumBytes})
   where
     maybeOutboxData = if isJust st.protocolException then Nothing else outboxData
@@ -491,14 +492,14 @@ takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent
     outboxNumBytes = maybe 0 BSL.length maybeOutboxData
 
 
-receiveMessages :: (IsSide s, MonadCatch m) => ProtocolAction s m ()
+receiveMessages :: IsSide s => ProtocolAction s ()
 receiveMessages = receiveRawMessage >>= \case
   Nothing -> pure ()
   Just rawMessage -> do
     handleRawMessage rawMessage
     receiveMessages
 
-handleRawMessage :: forall s m. MonadCatch m => RawMessage -> ProtocolAction s m ()
+handleRawMessage :: forall s. RawMessage -> ProtocolAction s ()
 handleRawMessage (oId, opcode, body) = do
   objects <- State.gets (.objects)
   case HM.lookup oId objects of
@@ -517,16 +518,16 @@ handleRawMessage (oId, opcode, body) = do
 
 getMessageAction
   :: IsInterfaceSide s i
-  => Object s m i
+  => Object s i
   -> Opcode
-  -> Get (ProtocolAction s m ())
+  -> Get (ProtocolAction s ())
 getMessageAction object@(Object _ objectHandler) opcode = do
   message <- getDown object opcode
   pure $ handleMessage objectHandler object message
 
 type RawMessage = (GenericObjectId, Opcode, BSL.ByteString)
 
-receiveRawMessage :: forall s m. MonadCatch m => ProtocolAction s m (Maybe RawMessage)
+receiveRawMessage :: forall s. ProtocolAction s (Maybe RawMessage)
 receiveRawMessage = do
   st <- State.get
   (result, newDecoder) <- checkDecoder st.inboxDecoder
@@ -535,7 +536,7 @@ receiveRawMessage = do
   where
     checkDecoder
       :: Decoder RawMessage
-      -> ProtocolAction s m (Maybe RawMessage, Decoder RawMessage)
+      -> ProtocolAction s (Maybe RawMessage, Decoder RawMessage)
     checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message)
     checkDecoder x@(Partial _) = pure (Nothing, x)
     checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)
@@ -579,7 +580,7 @@ padding :: Integral a => a -> a
 padding size = ((4 - (size `mod` 4)) `mod` 4)
 
 
-sendRawMessage :: MonadCatch m => Put -> ProtocolAction s m ()
+sendRawMessage :: Put -> ProtocolAction s ()
 sendRawMessage x = State.modify \st -> st {
   outbox = Just (maybe x (<> x) st.outbox)
 }
-- 
GitLab