diff --git a/src/Quasar/Wayland/Client.hs b/src/Quasar/Wayland/Client.hs
index 246548240fb2da1d01c2a57c491e3c86711300f0..7e8bccbeb7141e8cb992e6b0ae1201c540965d2b 100644
--- a/src/Quasar/Wayland/Client.hs
+++ b/src/Quasar/Wayland/Client.hs
@@ -30,8 +30,8 @@ newWaylandClient :: MonadResourceManager m => Socket -> m WaylandClient
 newWaylandClient socket = do
   (connection, wlDisplay) <- newWaylandConnection @I_wl_display (traceCallback ignoreMessage) socket
 
-  (_wlRegistry, newId) <- stepProtocol connection $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage)
-  stepProtocol connection $ sendMessage wlDisplay $ R_wl_display_get_registry newId
+  (_wlRegistry, newId) <- runProtocolM connection.protocolHandle $ newObject @'Client @I_wl_registry (traceCallback ignoreMessage)
+  runProtocolM connection.protocolHandle $ sendMessage wlDisplay $ R_wl_display_get_registry newId
   pure $ WaylandClient connection wlDisplay
 
 connectWaylandClient :: MonadResourceManager m => m WaylandClient
diff --git a/src/Quasar/Wayland/Connection.hs b/src/Quasar/Wayland/Connection.hs
index 4a401e3a8714bf40088628b7ac11bd978b2f98e8..6013a7ad8a23bf011e3bcf0185d1ea6cee8f9dbf 100644
--- a/src/Quasar/Wayland/Connection.hs
+++ b/src/Quasar/Wayland/Connection.hs
@@ -1,7 +1,6 @@
 module Quasar.Wayland.Connection (
-  WaylandConnection,
+  WaylandConnection(protocolHandle),
   newWaylandConnection,
-  stepProtocol,
 ) where
 
 import Control.Concurrent.STM
@@ -19,8 +18,7 @@ import Quasar.Wayland.Protocol.Generated
 
 
 data WaylandConnection s = WaylandConnection {
-  protocolStateVar :: TVar (ProtocolState s),
-  outboxVar :: TMVar BSL.ByteString,
+  protocolHandle :: ProtocolHandle s,
   socket :: Socket,
   resourceManager :: ResourceManager
 }
@@ -36,20 +34,18 @@ data SocketClosed = SocketClosed
   deriving anyclass Exception
 
 newWaylandConnection
-  :: forall wl_display wl_registry s m. (IsInterfaceSide s wl_display, MonadResourceManager m)
+  :: forall wl_display s m. (IsInterfaceSide s wl_display, MonadResourceManager m)
   => Callback s wl_display
   -> Socket
   -> m (WaylandConnection s, Object s wl_display)
 newWaylandConnection wlDisplayCallback socket = do
-  protocolStateVar <- liftIO $ newTVarIO protocolState
-  outboxVar <- liftIO newEmptyTMVarIO
+  (wlDisplay, protocolHandle) <- liftIO $ atomically $ initializeProtocol wlDisplayCallback pure
 
   resourceManager <- newResourceManager
 
   onResourceManager resourceManager do
     let connection = WaylandConnection {
-      protocolStateVar,
-      outboxVar,
+      protocolHandle,
       socket,
       resourceManager
     }
@@ -61,20 +57,6 @@ newWaylandConnection wlDisplayCallback socket = do
       connectionThread connection $ receiveThread connection
 
     pure (connection, wlDisplay)
-  where
-    (protocolState, wlDisplay) = initialProtocolState wlDisplayCallback
-
-stepProtocol :: forall s m a. MonadIO m => WaylandConnection s -> ProtocolStep s a -> m a
-stepProtocol connection step = liftIO do
-  result <- atomically do
-    oldState <- readTVar connection.protocolStateVar
-    (result, outBytes, newState) <- step oldState
-    writeTVar connection.protocolStateVar newState
-    mapM_ (putTMVar connection.outboxVar) outBytes
-    pure result
-  case result of
-    Left ex -> throwM (ex :: SomeException)
-    Right result -> pure result
 
 connectionThread :: MonadAsync m => WaylandConnection s -> IO () -> m ()
 connectionThread connection work = async_ $ liftIO $ work `catches` [ignoreCancelTask, handleAll]
@@ -84,7 +66,7 @@ connectionThread connection work = async_ $ liftIO $ work `catches` [ignoreCance
 
 sendThread :: WaylandConnection s -> IO ()
 sendThread connection = forever do
-  bytes <- atomically $ takeTMVar connection.outboxVar
+  bytes <- takeOutbox connection.protocolHandle
 
   traceIO $ "Sending " <> show (BSL.length bytes) <> " bytes"
   SocketL.sendAll connection.socket bytes
@@ -99,7 +81,7 @@ receiveThread connection = forever do
 
   traceIO $ "Received " <> show (BS.length bytes) <> " bytes"
 
-  stepProtocol connection $ feedInput bytes
+  feedInput connection.protocolHandle bytes
 
 closeConnection :: WaylandConnection s -> IO (Awaitable ())
 closeConnection connection = do
diff --git a/src/Quasar/Wayland/Protocol.hs b/src/Quasar/Wayland/Protocol.hs
index 7e8d34b3e638eef054976aa93c161d5301596700..e770a3db6b7356107e6d68b373d722fbd3654956 100644
--- a/src/Quasar/Wayland/Protocol.hs
+++ b/src/Quasar/Wayland/Protocol.hs
@@ -12,24 +12,23 @@ import Quasar.Wayland.Protocol.Core
 import Quasar.Wayland.Protocol.Generated
 
 
-createClientStateWithRegistry :: STM (ProtocolState 'Client)
+createClientStateWithRegistry :: STM (ProtocolHandle 'Client)
 createClientStateWithRegistry = do
-  (wlRegistry, state') <- runStateT go initialState'
-  pure state'
+  (wlRegistry, protocolHandle) <- initializeProtocol wlDisplayCallback createRegistry
+  pure protocolHandle
   where
-    (initialState', wlDisplay) = initialProtocolState wlDisplayCallback
-
-    go :: ProtocolAction 'Client (Object 'Client I_wl_registry)
-    go = do
-      (wlRegistry, newId) <- newObjectInternal @'Client @I_wl_registry (traceCallback ignoreMessage)
-      sendMessageInternal wlDisplay $ R_wl_display_get_registry newId
+    createRegistry :: Object 'Client I_wl_display -> ProtocolM 'Client (Object 'Client I_wl_registry)
+    createRegistry wlDisplay = do
+      (wlRegistry, newId) <- newObject @'Client @I_wl_registry (traceCallback ignoreMessage)
+      sendMessage wlDisplay $ R_wl_display_get_registry newId
 
       pure wlRegistry
 
     wlDisplayCallback :: IsInterfaceSide 'Client I_wl_display => Callback 'Client I_wl_display
     wlDisplayCallback = internalFnCallback handler
       where
-        handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolAction 'Client ()
+        -- | wl_display is specified to never change, so manually specifying the callback is safe
+        handler :: Object 'Client I_wl_display -> E_wl_display -> ProtocolM 'Client ()
         -- TODO parse oId
         handler _ (E_wl_display_error oId code message) = throwM $ ServerError code (toString message)
         handler _ (E_wl_display_delete_id deletedId) = pure () -- TODO confirm delete
diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index 73067ae272bcb797a7604d23647218432de61c3f..6cb097b157c03a127fca1bd85e5964f4cdb4d201 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -15,26 +15,36 @@ module Quasar.Wayland.Protocol.Core (
   Object,
   IsObject,
   IsMessage(..),
-  ProtocolState,
-  ProtocolAction,
+  ProtocolHandle,
+  ProtocolM,
+
+  -- * Protocol IO
+  initializeProtocol,
+  feedInput,
+  setException,
+  takeOutbox,
+  runProtocolM,
+
+  -- * Low-level protocol interaction
+  sendMessage,
+  newObject,
+
+  -- ** Callbacks
   Callback(..),
   internalFnCallback,
   traceCallback,
   ignoreMessage,
-  ProtocolStep,
-  initialProtocolState,
-  sendMessage,
-  newObject,
-  feedInput,
-  setException,
-  newObjectInternal,
-  sendMessageInternal,
-
-  showObjectMessage,
-  isNewId,
 
+  -- * Protocol exceptions
+  CallbackFailed(..),
+  ParserFailed(..),
+  ProtocolException(..),
+  MaximumIdReached(..),
   ServerError(..),
 
+  -- * TH utilities
+  isNewId,
+
   -- * Message decoder operations
   WireFormat(..),
   dropRemaining,
@@ -44,7 +54,7 @@ module Quasar.Wayland.Protocol.Core (
 import Control.Concurrent.STM
 import Control.Monad (replicateM_)
 import Control.Monad.Catch
-import Control.Monad.State (StateT, runStateT)
+import Control.Monad.Reader (ReaderT, runReaderT, ask, asks, lift)
 import Control.Monad.State qualified as State
 import Data.Binary
 import Data.Binary.Get
@@ -222,7 +232,7 @@ putUp _ = putMessage @(Up s i)
 
 
 class IsInterfaceSide s i => IsInterfaceHandler s i a where
-  handleMessage :: a -> Object s i -> Down s i -> ProtocolAction s ()
+  handleMessage :: a -> Object s i -> Down s i -> ProtocolM s ()
 
 
 -- | Data kind
@@ -293,29 +303,18 @@ showObjectMessage object message =
   objectInterfaceName object <> "@" <> show (objectId object) <> "." <> show message
 
 
-data ProtocolState (s :: Side) = ProtocolState {
-  protocolException :: Maybe SomeException,
-  bytesReceived :: !Int64,
-  bytesSent :: !Int64,
-  inboxDecoder :: Decoder RawMessage,
-  outbox :: Maybe Put,
-  objects :: HashMap GenericObjectId (SomeObject s),
-  nextId :: Word32
-}
-
-
 data Callback s i = forall a. IsInterfaceHandler s i a => Callback a
 
 instance IsInterfaceSide s i => IsInterfaceHandler s i (Callback s i) where
   handleMessage (Callback callback) = handleMessage callback
 
 
-data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolAction s ())
+data LowLevelCallback s i = IsInterfaceSide s i => FnCallback (Object s i -> Down s i -> ProtocolM s ())
 
 instance IsInterfaceSide s i => IsInterfaceHandler s i (LowLevelCallback s i) where
   handleMessage (FnCallback fn) object msg = fn object msg
 
-internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolAction s ()) -> Callback s i
+internalFnCallback :: IsInterfaceSide s i => (Object s i -> Down s i -> ProtocolM s ()) -> Callback s i
 internalFnCallback = Callback . FnCallback
 
 
@@ -359,113 +358,158 @@ data ServerError = ServerError Word32 String
   deriving stock Show
   deriving anyclass Exception
 
--- * Monad plumbing
+-- * Protocol state and monad plumbing
 
-type ProtocolStep s a = ProtocolState s -> STM (Either SomeException a, Maybe BSL.ByteString, ProtocolState s)
+-- | Top-level protocol handle (used e.g. to send/receive data)
+newtype ProtocolHandle (s :: Side) = ProtocolHandle {
+  stateVar :: TVar (Either SomeException (ProtocolState s))
+}
 
--- Must not be exported. 'ProtocolStep' ensures proper protocol failure in case of exceptions.
-type ProtocolAction s a = StateT (ProtocolState s) STM a
+-- | Protocol state handle, containing state for a non-failed protocol (should be kept in a 'ProtocolStateVar')
+data ProtocolState (s :: Side) = ProtocolState {
+  bytesReceivedVar :: TVar Int64,
+  bytesSentVar :: TVar Int64,
+  inboxDecoderVar :: TVar (Decoder RawMessage),
+  outboxVar :: TVar (Maybe Put),
+  objectsVar :: TVar (HashMap GenericObjectId (SomeObject s)),
+  nextIdVar :: TVar Word32
+}
 
-protocolStep :: forall s a. ProtocolAction s a -> ProtocolStep s a
-protocolStep action inState = do
-  mapM_ throwM inState.protocolException
-  (result, (outbox, outState)) <- fmap takeOutbox . storeExceptionIfFailed <$> runStateT (try action) inState
-  pure (result, outbox, outState)
-  where
-    storeExceptionIfFailed :: (Either SomeException a, ProtocolState s) -> (Either SomeException a, ProtocolState s)
-    storeExceptionIfFailed (Left ex, st) = (Left ex, setException' ex st)
-    storeExceptionIfFailed x = x
-    setException' :: Exception e => e -> ProtocolState s -> ProtocolState s
-    setException' ex st =
-      if isJust st.protocolException
-        then st
-        else st{protocolException = Just (toException ex)}
+type ProtocolM s a = ReaderT (ProtocolState s) STM a
+
+readProtocolVar :: (ProtocolState s -> TVar a) -> ProtocolM s a
+readProtocolVar fn = do
+  state <- ask
+  lift $ readTVar (fn state)
 
+writeProtocolVar :: (ProtocolState s -> TVar a) -> a -> ProtocolM s ()
+writeProtocolVar fn x = do
+  state <- ask
+  lift $ writeTVar (fn state) x
 
--- * Exported functions
+modifyProtocolVar :: (ProtocolState s -> TVar a) -> (a -> a) -> ProtocolM s ()
+modifyProtocolVar fn x = do
+  state <- ask
+  lift $ modifyTVar (fn state) x
 
-initialProtocolState
-  :: forall wl_display wl_registry s. (IsInterfaceSide s wl_display)
+stateProtocolVar :: (ProtocolState s -> TVar a) -> (a -> (r, a)) -> ProtocolM s r
+stateProtocolVar fn x = do
+  state <- ask
+  lift $ stateTVar (fn state) x
+
+initializeProtocol
+  :: forall s wl_display a. (IsInterfaceSide s wl_display)
   => Callback s wl_display
-  -> (ProtocolState s, Object s wl_display)
-initialProtocolState wlDisplayCallback = (initialState, wlDisplay)
+  -> (Object s wl_display -> ProtocolM s a)
+  -> STM (a, ProtocolHandle s)
+initializeProtocol wlDisplayCallback initializationAction = do
+  bytesReceivedVar <- newTVar 0
+  bytesSentVar <- newTVar 0
+  inboxDecoderVar <- newTVar $ runGetIncremental getRawMessage
+  outboxVar <- newTVar Nothing
+  objectsVar <- newTVar $ HM.fromList [(1, (SomeObject wlDisplay))]
+  nextIdVar <- newTVar (initialId @s)
+  let state = ProtocolState {
+    bytesReceivedVar,
+    bytesSentVar,
+    inboxDecoderVar,
+    outboxVar,
+    objectsVar,
+    nextIdVar
+  }
+  stateVar <- newTVar (Right state)
+  let handle = ProtocolHandle {
+    stateVar
+  }
+  result <- runReaderT (initializationAction wlDisplay) state
+  pure (result, handle)
   where
     wlDisplay :: Object s wl_display
     wlDisplay = Object 1 wlDisplayCallback
-    initialState :: ProtocolState s
-    initialState = ProtocolState {
-      protocolException = Nothing,
-      bytesReceived = 0,
-      bytesSent = 0,
-      inboxDecoder = runGetIncremental getRawMessage,
-      outbox = Nothing,
-      objects = HM.fromList [(1, (SomeObject wlDisplay))],
-      nextId = initialId @s
-    }
-
--- | Feed the protocol newly received data
-feedInput :: IsSide s => ByteString -> ProtocolStep s ()
-feedInput bytes = protocolStep do
-  feed
+
+-- | Entry point to run a protocol action, effectively an 'atomically' with correct error handling.
+--
+-- Throws an exception, when the protocol reaches or is in a failed (/error) state.
+runProtocolM :: (MonadIO m, MonadThrow m) => ProtocolHandle s -> ProtocolM s a -> m a
+runProtocolM (ProtocolHandle stateVar) action = do
+  result <- liftIO $ atomically do
+    readTVar stateVar >>= \case
+      -- Protocol is already in a failed state
+      Left ex -> throwM ex
+      Right state -> do
+        -- Run action, catch exceptions
+        result <- runReaderT (try action) state
+        case result of
+          Left ex -> do
+            -- Action failed, change protocol state to failed
+            writeTVar stateVar (Left ex)
+            pure (Left ex)
+          Right result -> do
+            pure (Right result)
+  -- Transaction is committed, rethrow exception if the action failed
+  either throwM pure result
+
+
+
+-- | Feed the protocol newly received data.
+feedInput :: (IsSide s, MonadIO m, MonadThrow m) => ProtocolHandle s -> ByteString -> m ()
+feedInput protocol bytes = runProtocolM protocol do
+  -- Exposing MonadIO instead of STM to the outside and using `runProtocolM` here enforces correct exception handling.
+  modifyProtocolVar (.bytesReceivedVar) (+ fromIntegral (BS.length bytes))
+  modifyProtocolVar (.inboxDecoderVar) (`pushChunk` bytes)
   receiveMessages
-  where
-    feed = State.modify \st -> st {
-      bytesReceived = st.bytesReceived + fromIntegral (BS.length bytes),
-      inboxDecoder = pushChunk st.inboxDecoder bytes
-    }
 
-setException :: Exception e => e -> ProtocolStep s ()
-setException ex = protocolStep do
-  State.modify \st -> st{protocolException = Just (toException ex)}
+-- | Set the protocol to a failed state, e.g. when the socket closed unexpectedly.
+setException :: (Exception e, MonadIO m, MonadThrow m) => ProtocolHandle s -> e -> m ()
+setException protocol ex = runProtocolM protocol $ throwM ex
+
+-- | Take data that has to be sent. Blocks until data is available.
+takeOutbox :: (MonadIO m, MonadThrow m) => ProtocolHandle s -> m (BSL.ByteString)
+takeOutbox protocol = runProtocolM protocol do
+  mOutboxData <- stateProtocolVar (.outboxVar) (\mOutboxData -> (mOutboxData, Nothing))
+  outboxData <- maybe (lift retry) pure mOutboxData
+  let sendData = runPut outboxData
+  modifyProtocolVar (.bytesSentVar) (+ BSL.length sendData)
+  pure sendData
 
 
 -- | Create an object. The caller is responsible for sending the 'NewId' exactly once before using the object.
 newObject
   :: forall s i. IsInterfaceSide s i
   => Callback s i
-  -> ProtocolStep s (Object s i, NewId (InterfaceName i))
-newObject callback = protocolStep $ newObjectInternal callback
-
-newObjectInternal
-  :: forall s i. IsInterfaceSide s i
-  => Callback s i
-  -> ProtocolAction s (Object s i, NewId (InterfaceName i))
-newObjectInternal callback = do
-  genOId <- allocateObjectId @s
+  -> ProtocolM s (Object s i, NewId (InterfaceName i))
+newObject callback = do
+  genOId <- allocateObjectId
   let oId = NewId @(InterfaceName i) genOId
   object <- newObjectFromId oId callback
   pure (object, oId)
   where
-    allocateObjectId :: forall s. IsSide s => ProtocolAction s GenericObjectId
+    allocateObjectId :: ProtocolM s GenericObjectId
     allocateObjectId = do
-      st <- State.get
-      let
-        id = st.nextId
-        nextId' = id + 1
+      id' <- readProtocolVar (.nextIdVar)
 
+      let nextId' = id' + 1
       when (nextId' == maximumId @s) $ throwM MaximumIdReached
-      State.put $ st {nextId = nextId'}
-      pure id
+
+      writeProtocolVar (.nextIdVar) nextId'
+      pure id'
 
 newObjectFromId
   :: forall s i. IsInterfaceSide s i
   => NewId (InterfaceName i)
   -> Callback s i
-  -> ProtocolAction s (Object s i)
+  -> ProtocolM s (Object s i)
 newObjectFromId (NewId oId) callback = do
   let
     object = Object oId callback
     someObject = SomeObject object
-  State.modify \st -> st { objects = HM.insert oId someObject st.objects}
+  modifyProtocolVar (.objectsVar) (HM.insert oId someObject)
   pure object
 
 
--- | Sends a message without checking any ids or creating proxy objects objects.
-sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolStep s ()
-sendMessage object message = protocolStep $ sendMessageInternal object message
-
-sendMessageInternal :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolAction s ()
-sendMessageInternal object message = do
+-- | Sends a message without checking any ids or creating proxy objects objects. (TODO)
+sendMessage :: forall s i. IsInterfaceSide s i => Object s i -> Up s i -> ProtocolM s ()
+sendMessage object message = do
   traceM $ "-> " <> showObjectMessage object message
   sendRawMessage messageWithHeader
   where
@@ -483,25 +527,17 @@ sendMessageInternal object message = do
     msgSizeInteger :: Integer
     msgSizeInteger = 8 + fromIntegral (BSL.length body)
 
--- | Take data that has to be sent (if available)
-takeOutbox :: ProtocolState s ->  (Maybe BSL.ByteString, ProtocolState s)
-takeOutbox st = (maybeOutboxData, st{outbox = Nothing, bytesSent = st.bytesSent + outboxNumBytes})
-  where
-    maybeOutboxData = if isJust st.protocolException then Nothing else outboxData
-    outboxData = runPut <$> st.outbox
-    outboxNumBytes = maybe 0 BSL.length maybeOutboxData
-
 
-receiveMessages :: IsSide s => ProtocolAction s ()
+receiveMessages :: IsSide s => ProtocolM s ()
 receiveMessages = receiveRawMessage >>= \case
   Nothing -> pure ()
   Just rawMessage -> do
     handleRawMessage rawMessage
     receiveMessages
 
-handleRawMessage :: forall s. RawMessage -> ProtocolAction s ()
+handleRawMessage :: forall s. RawMessage -> ProtocolM s ()
 handleRawMessage (oId, opcode, body) = do
-  objects <- State.gets (.objects)
+  objects <- readProtocolVar (.objectsVar)
   case HM.lookup oId objects of
     Nothing -> throwM $ ProtocolException $ "Received message with invalid object id " <> show oId
 
@@ -520,23 +556,22 @@ getMessageAction
   :: IsInterfaceSide s i
   => Object s i
   -> Opcode
-  -> Get (ProtocolAction s ())
+  -> Get (ProtocolM s ())
 getMessageAction object@(Object _ objectHandler) opcode = do
   message <- getDown object opcode
   pure $ handleMessage objectHandler object message
 
 type RawMessage = (GenericObjectId, Opcode, BSL.ByteString)
 
-receiveRawMessage :: forall s. ProtocolAction s (Maybe RawMessage)
+receiveRawMessage :: forall s. ProtocolM s (Maybe RawMessage)
 receiveRawMessage = do
-  st <- State.get
-  (result, newDecoder) <- checkDecoder st.inboxDecoder
-  State.put st{inboxDecoder = newDecoder}
+  (result, nextDecoder) <- checkDecoder =<< readProtocolVar (.inboxDecoderVar)
+  writeProtocolVar (.inboxDecoderVar) nextDecoder
   pure result
   where
     checkDecoder
       :: Decoder RawMessage
-      -> ProtocolAction s (Maybe RawMessage, Decoder RawMessage)
+      -> ProtocolM s (Maybe RawMessage, Decoder RawMessage)
     checkDecoder (Fail _ _ message) = throwM (ParserFailed "RawMessage" message)
     checkDecoder x@(Partial _) = pure (Nothing, x)
     checkDecoder (Done leftovers _ result) = pure (Just result, pushChunk (runGetIncremental getRawMessage) leftovers)
@@ -580,7 +615,5 @@ padding :: Integral a => a -> a
 padding size = ((4 - (size `mod` 4)) `mod` 4)
 
 
-sendRawMessage :: Put -> ProtocolAction s ()
-sendRawMessage x = State.modify \st -> st {
-  outbox = Just (maybe x (<> x) st.outbox)
-}
+sendRawMessage :: Put -> ProtocolM s ()
+sendRawMessage x = modifyProtocolVar (.outboxVar) (Just . maybe x (<> x))