From a2bf3109d1df8b755dee8f6d55b6f95dafc689cb Mon Sep 17 00:00:00 2001
From: Jens Nolte <git@queezle.net>
Date: Tue, 14 Sep 2021 23:11:27 +0200
Subject: [PATCH] Fix invalid opcode handling

---
 src/Quasar/Wayland/Protocol/Core.hs |  1 +
 src/Quasar/Wayland/Protocol/TH.hs   | 19 +++++++++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/Quasar/Wayland/Protocol/Core.hs b/src/Quasar/Wayland/Protocol/Core.hs
index 762b5f1..e983781 100644
--- a/src/Quasar/Wayland/Protocol/Core.hs
+++ b/src/Quasar/Wayland/Protocol/Core.hs
@@ -32,6 +32,7 @@ module Quasar.Wayland.Protocol.Core (
   -- * Message decoder operations
   WireFormat(..),
   dropRemaining,
+  invalidOpcode,
 ) where
 
 import Control.Monad (replicateM_)
diff --git a/src/Quasar/Wayland/Protocol/TH.hs b/src/Quasar/Wayland/Protocol/TH.hs
index 92393ea..dab855a 100644
--- a/src/Quasar/Wayland/Protocol/TH.hs
+++ b/src/Quasar/Wayland/Protocol/TH.hs
@@ -224,16 +224,23 @@ isMessageInstanceD :: Q Type -> [MessageContext] -> Q Dec
 isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, getMessageD, putMessageD]
   where
     opcodeNameD :: Q Dec
-    opcodeNameD = funD 'opcodeName (opcodeNameClauseD <$> msgs)
-    opcodeNameClauseD :: MessageContext -> Q Clause
-    opcodeNameClauseD msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) []
+    opcodeNameD = funD 'opcodeName ((opcodeNameClause <$> msgs) <> [opcodeNameInvalidClause])
+    opcodeNameClause :: MessageContext -> Q Clause
+    opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) []
+    opcodeNameInvalidClause :: Q Clause
+    opcodeNameInvalidClause = clause [wildP] (normalB ([|Nothing|])) []
     getMessageD :: Q Dec
-    getMessageD = funD 'getMessage (getMessageClauseD <$> msgs)
-    getMessageClauseD :: MessageContext -> Q Clause
-    getMessageClauseD msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
+    getMessageD = funD 'getMessage ((getMessageClause <$> msgs) <> [getMessageInvalidOpcodeClause])
+    getMessageClause :: MessageContext -> Q Clause
+    getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
       where
         getMessageE :: Q Exp
         getMessageE = applyA (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments)
+    getMessageInvalidOpcodeClause :: Q Clause
+    getMessageInvalidOpcodeClause = do
+      let object = mkName "object"
+      let opcode = mkName "opcode"
+      clause [varP object, varP opcode] (normalB [|invalidOpcode $(varE object) $(varE opcode)|]) []
     putMessageD :: Q Dec
     putMessageD = funD 'putMessage (putMessageClauseD <$> msgs)
     putMessageClauseD :: MessageContext -> Q Clause
-- 
GitLab