Skip to content
Snippets Groups Projects
Commit 91448cb2 authored by Jens Nolte's avatar Jens Nolte
Browse files

Implement incremental channel message decoder


Co-authored-by: default avatarJan Beinke <git@janbeinke.com>
parent a2e85dc2
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ import qualified Control.Monad.State as State ...@@ -10,7 +10,7 @@ import qualified Control.Monad.State as State
import Control.Concurrent.MVar import Control.Concurrent.MVar
import Data.Binary (Binary, encode, decodeOrFail) import Data.Binary (Binary, encode, decodeOrFail)
import qualified Data.Binary as Binary import qualified Data.Binary as Binary
import Data.Binary.Get (Decoder(..), runGetIncremental, pushChunk) import Data.Binary.Get (Decoder(..), runGetIncremental, pushChunk, pushEndOfInput)
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL import qualified Data.ByteString.Lazy as BSL
import Data.Hashable (Hashable) import Data.Hashable (Hashable)
...@@ -337,24 +337,54 @@ metaProtocolReceive worker = receiveThreadLoop metaDecoder ...@@ -337,24 +337,54 @@ metaProtocolReceive worker = receiveThreadLoop metaDecoder
handleMetaMessage (ChannelMessage headers len) = do handleMetaMessage (ChannelMessage headers len) = do
workerState <- liftIO $ readMVar worker.stateMVar workerState <- liftIO $ readMVar worker.stateMVar
case HM.lookup workerState.receiveChannel workerState.channels of case HM.lookup workerState.receiveChannel workerState.channels of
Just channel -> do Just channel -> handleChannelMessage channel headers len
-- StateT currently contains leftovers
(rawMsg, leftovers) <- liftIO . readRawMessage . BSL.fromStrict =<< get
-- Data is received in chunks but messages have a defined length, so leftovers are put back into StateT
put $ BSL.toStrict leftovers
headerResults <- liftIO $ sequence (processHeader <$> headers)
-- Critical section: don't interrupt downstream callbacks
liftIO $ withMVar worker.killReceiverMVar $ \_ -> channelHandleMessage channel headerResults rawMsg
Nothing -> liftIO $ reportProtocolError worker ("Received message on invalid channel: " <> show workerState.receiveChannel) Nothing -> liftIO $ reportProtocolError worker ("Received message on invalid channel: " <> show workerState.receiveChannel)
handleMultiplexerMessage (SwitchChannel channelId) = liftIO $ modifyMVar_ worker.stateMVar $ \state -> pure state{receiveChannel=channelId}
handleMultiplexerMessage x = liftIO $ print x >> undefined -- Unhandled meta message
handleChannelMessage :: Channel -> [MetaProtocolMessageHeader] -> MessageLength -> StateT BS.ByteString IO ()
handleChannelMessage channel headers len = do
headerResults <- liftIO $ sequence (processHeader <$> headers)
decoder <- liftIO $ channelStartHandleMessage channel headerResults
-- StateT currently contains leftovers
initialLeftovers <- get
let
leftoversLength = fromIntegral $ BS.length initialLeftovers
remaining = len - leftoversLength
(channelCallback, leftovers) <- liftIO $ runDecoder remaining (pushChunk decoder initialLeftovers)
-- Data is received in chunks but messages have a defined length, so leftovers are put back into StateT
put leftovers
-- Critical section: don't interrupt downstream callbacks
liftIO $ withMVar worker.killReceiverMVar $ const channelCallback
where where
readRawMessage :: BSL.ByteString -> IO (BSL.ByteString, BSL.ByteString) runDecoder :: MessageLength -> Decoder (IO ()) -> IO (IO (), BS.ByteString)
readRawMessage x runDecoder _ (Fail _ _ err) = failedToParseMessage err
| fromIntegral (BSL.length x) >= len = return $ BSL.splitAt (fromIntegral len) x runDecoder 0 (Partial feedFn) = finalizeDecoder "" (feedFn Nothing)
| otherwise = readRawMessage . BSL.append x . BSL.fromStrict =<< receiveThrowing runDecoder remaining (Partial feedFn) = do
chunk <- receiveThrowing
let chunkLength = fromIntegral $ BS.length chunk
if chunkLength <= remaining
then runDecoder (remaining - chunkLength) (feedFn (Just chunk))
else do
let (partialChunk, leftovers) = BS.splitAt (fromIntegral remaining) chunk
finalizeDecoder leftovers $ pushEndOfInput $ feedFn $ Just partialChunk
runDecoder 0 decoder@Done{} = finalizeDecoder "" decoder
runDecoder _ (Done _ bytesRead _) = failedToConsumeAllInput (fromIntegral bytesRead)
finalizeDecoder :: BS.ByteString -> Decoder (IO ()) -> IO (IO (), BS.ByteString)
finalizeDecoder _ (Fail _ _ err) = failedToParseMessage err
finalizeDecoder _ (Partial _) = failedToTerminate
finalizeDecoder leftovers (Done "" _ result) = return (result, leftovers)
finalizeDecoder _ (Done _ bytesRead _) = failedToConsumeAllInput (fromIntegral bytesRead)
failedToParseMessage :: String -> IO a
failedToParseMessage err = reportProtocolError worker ("Failed to parse message on channel " <> show channel.channelId <> ": " <> err)
failedToConsumeAllInput :: MessageLength -> IO a
failedToConsumeAllInput bytesRead = reportProtocolError worker ("Decoder for channel " <> show channel.channelId <> " failed to consume all input (" <> show (len - bytesRead) <> " bytes left)")
failedToTerminate :: IO a
failedToTerminate = reportLocalError worker ("Decoder on channel " <> show channel.channelId <> " failed to terminate after end-of-input")
processHeader :: MetaProtocolMessageHeader -> IO MessageHeaderResult processHeader :: MetaProtocolMessageHeader -> IO MessageHeaderResult
processHeader CreateChannel = undefined processHeader CreateChannel = undefined
handleMetaMessage (SwitchChannel channelId) = liftIO $ modifyMVar_ worker.stateMVar $ \state -> return state{receiveChannel=channelId}
handleMetaMessage x = liftIO $ print x >> undefined -- Unhandled meta message
receiveThrowing :: IO BS.ByteString receiveThrowing :: IO BS.ByteString
receiveThrowing = do receiveThrowing = do
state <- readMVar worker.stateMVar state <- readMVar worker.stateMVar
...@@ -443,7 +473,20 @@ data ChannelReceiveState = ChannelReceiveState { ...@@ -443,7 +473,20 @@ data ChannelReceiveState = ChannelReceiveState {
nextMessageId :: MessageId, nextMessageId :: MessageId,
handler :: ChannelMessageHandler handler :: ChannelMessageHandler
} }
type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO () type SimpleChannelMessageHandler = MessageId -> [MessageHeaderResult] -> BSL.ByteString -> IO ()
type ChannelMessageHandler = MessageId -> [MessageHeaderResult] -> Decoder (IO ())
simpleMessageHandler :: SimpleChannelMessageHandler -> ChannelMessageHandler
simpleMessageHandler handler msgId headers = decoder ""
where
decoder :: BSL.ByteString -> Decoder (IO ())
decoder acc = Partial (maybe done partial)
where
partial :: BS.ByteString -> Decoder (IO ())
partial = decoder . (acc <>) . BSL.fromStrict
done :: Decoder (IO ())
done = Done "" (BSL.length acc) (handler msgId headers acc)
-- Should not be exported -- Should not be exported
newChannel :: MetaProtocolWorker -> ChannelId -> IO Channel newChannel :: MetaProtocolWorker -> ChannelId -> IO Channel
...@@ -451,7 +494,7 @@ newChannel worker channelId = do ...@@ -451,7 +494,7 @@ newChannel worker channelId = do
sendStateMVar <- newMVar ChannelSendState { sendStateMVar <- newMVar ChannelSendState {
nextMessageId = 0 nextMessageId = 0
} }
let handler = (\_ _ _ -> reportLocalError worker ("Channel " <> show channelId <> ": Received message but no Handler is registered")) let handler = (simpleMessageHandler $ \_ _ _ -> reportLocalError worker ("Channel " <> show channelId <> ": Received message but no Handler is registered"))
receiveStateMVar <- newMVar ChannelReceiveState { receiveStateMVar <- newMVar ChannelReceiveState {
nextMessageId = 0, nextMessageId = 0,
handler handler
...@@ -474,10 +517,11 @@ channelSend_ :: Channel -> BSL.ByteString -> [MessageHeader] -> IO () ...@@ -474,10 +517,11 @@ channelSend_ :: Channel -> BSL.ByteString -> [MessageHeader] -> IO ()
channelSend_ channel msg headers = channelSend channel msg headers (const (return ())) channelSend_ channel msg headers = channelSend channel msg headers (const (return ()))
channelClose :: Channel -> IO () channelClose :: Channel -> IO ()
channelClose channel = metaChannelClose channel.worker channel.channelId channelClose channel = metaChannelClose channel.worker channel.channelId
channelHandleMessage :: Channel -> [MessageHeaderResult] -> BSL.ByteString -> IO () channelStartHandleMessage :: Channel -> [MessageHeaderResult] -> IO (Decoder (IO ()))
channelHandleMessage channel headers msg = modifyMVar_ channel.receiveStateMVar $ \state -> do channelStartHandleMessage channel headers = do
state.handler state.nextMessageId headers msg (msgId, handler) <- modifyMVar channel.receiveStateMVar $ \state ->
return state{nextMessageId = state.nextMessageId + 1} return (state{nextMessageId = state.nextMessageId + 1}, (state.nextMessageId, state.handler))
return (handler msgId headers)
channelSetHandler :: Channel -> ChannelMessageHandler -> IO () channelSetHandler :: Channel -> ChannelMessageHandler -> IO ()
channelSetHandler channel handler = modifyMVar_ channel.receiveStateMVar $ \state -> return state{handler} channelSetHandler channel handler = modifyMVar_ channel.receiveStateMVar $ \state -> return state{handler}
...@@ -552,11 +596,7 @@ serverHandleChannelMessage protocolImpl channel msgId headers msg = case decodeO ...@@ -552,11 +596,7 @@ serverHandleChannelMessage protocolImpl channel msgId headers msg = case decodeO
wrappedResponse = (msgId, response) wrappedResponse = (msgId, response)
registerChannelServerHandler :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Channel -> IO () registerChannelServerHandler :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Channel -> IO ()
registerChannelServerHandler protocolImpl channel = channelSetHandler channel (serverHandleChannelMessage @p protocolImpl channel) registerChannelServerHandler protocolImpl channel = channelSetHandler channel (simpleMessageHandler (serverHandleChannelMessage @p protocolImpl channel))
data Future a
data Sink a
data Source a
-- ** Running client and server -- ** Running client and server
...@@ -648,7 +688,7 @@ newChannelClient channel = do ...@@ -648,7 +688,7 @@ newChannelClient channel = do
channel, channel,
stateMVar stateMVar
} }
channelSetHandler channel (clientHandleChannelMessage client) channelSetHandler channel (simpleMessageHandler (clientHandleChannelMessage client))
return client return client
listenTCP :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Maybe Socket.HostName -> Socket.ServiceName -> IO () listenTCP :: forall p. (RpcProtocol p, HasProtocolImpl p) => ProtocolImpl p -> Maybe Socket.HostName -> Socket.ServiceName -> IO ()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment