From a2bf3109d1df8b755dee8f6d55b6f95dafc689cb Mon Sep 17 00:00:00 2001 From: Jens Nolte <git@queezle.net> Date: Tue, 14 Sep 2021 23:11:27 +0200 Subject: [PATCH] Fix invalid opcode handling --- src/Quasar/Wayland/Protocol/Core.hs | 1 + src/Quasar/Wayland/Protocol/TH.hs | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs index 762b5f1..e983781 100644 --- a/src/Quasar/Wayland/Protocol/Core.hs +++ b/src/Quasar/Wayland/Protocol/Core.hs @@ -32,6 +32,7 @@ module Quasar.Wayland.Protocol.Core ( -- * Message decoder operations WireFormat(..), dropRemaining, + invalidOpcode, ) where import Control.Monad (replicateM_) diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs index 92393ea..dab855a 100644 --- a/src/Quasar/Wayland/Protocol/TH.hs +++ b/src/Quasar/Wayland/Protocol/TH.hs @@ -224,16 +224,23 @@ isMessageInstanceD :: Q Type -> [MessageContext] -> Q Dec isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, getMessageD, putMessageD] where opcodeNameD :: Q Dec - opcodeNameD = funD 'opcodeName (opcodeNameClauseD <$> msgs) - opcodeNameClauseD :: MessageContext -> Q Clause - opcodeNameClauseD msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) [] + opcodeNameD = funD 'opcodeName ((opcodeNameClause <$> msgs) <> [opcodeNameInvalidClause]) + opcodeNameClause :: MessageContext -> Q Clause + opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) [] + opcodeNameInvalidClause :: Q Clause + opcodeNameInvalidClause = clause [wildP] (normalB ([|Nothing|])) [] getMessageD :: Q Dec - getMessageD = funD 'getMessage (getMessageClauseD <$> msgs) - getMessageClauseD :: MessageContext -> Q Clause - getMessageClauseD msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) [] + getMessageD = funD 'getMessage ((getMessageClause <$> msgs) <> [getMessageInvalidOpcodeClause]) + getMessageClause :: MessageContext -> Q Clause + getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) [] where getMessageE :: Q Exp getMessageE = applyA (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments) + getMessageInvalidOpcodeClause :: Q Clause + getMessageInvalidOpcodeClause = do + let object = mkName "object" + let opcode = mkName "opcode" + clause [varP object, varP opcode] (normalB [|invalidOpcode $(varE object) $(varE opcode)|]) [] putMessageD :: Q Dec putMessageD = funD 'putMessage (putMessageClauseD <$> msgs) putMessageClauseD :: MessageContext -> Q Clause -- GitLab