module Quasar.Wayland.Protocol.TH (
  generateWaylandProcol
) where

import Control.Monad.Catch
import Control.Monad.Writer
import Data.Binary
import Data.ByteString qualified as BS
import Language.Haskell.TH
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax (BangType, VarBangType, addDependentFile)
import Language.Haskell.TH.Syntax qualified as TH
import Data.List (intersperse)
import Quasar.Prelude
import Quasar.Wayland.Protocol.Core
import Text.XML.Light


data ProtocolSpec = ProtocolSpec {interfaces :: [InterfaceSpec]}
  deriving stock Show

data InterfaceSpec = InterfaceSpec {
  name :: String,
  version :: Integer,
  requests :: [RequestSpec],
  events :: [EventSpec]
}
  deriving stock Show

newtype RequestSpec = RequestSpec {messageSpec :: MessageSpec}
  deriving stock Show

newtype EventSpec = EventSpec {messageSpec :: MessageSpec}
  deriving stock Show

data MessageSpec = MessageSpec {
  name :: String,
  since :: Maybe Integer,
  opcode :: Opcode,
  arguments :: [ArgumentSpec]
}
  deriving stock Show

data ArgumentSpec = ArgumentSpec {
  name :: String,
  index :: Integer,
  argType :: ArgumentType
}
  deriving stock Show


generateWaylandProcol :: FilePath -> Q [Dec]
generateWaylandProcol protocolFile = do
  addDependentFile protocolFile
  xml <- liftIO (BS.readFile protocolFile)
  protocol <- parseProtocol xml
  (public, internals) <- unzip <$> mapM interfaceDecs protocol.interfaces
  pure $ mconcat public <> mconcat internals


tellQ :: Q a -> WriterT [a] Q ()
tellQ action = tell =<< lift (singleton <$> action)
  where
    -- TODO use from base (base-4.14.0.0)
    singleton :: a -> [a]
    singleton x = [x]

tellQs :: Q [a] -> WriterT [a] Q ()
tellQs = tell <=< lift

interfaceDecs :: InterfaceSpec -> Q ([Dec], [Dec])
interfaceDecs interface = do
  public <- execWriterT do
    tellQ requestClassD
    tellQ eventClassD
  internals <- execWriterT do
    tellQ $ dataD (pure []) iName [] Nothing [normalC iName []] [derivingInterfaceClient, derivingInterfaceServer]
    tellQ $ instanceD (pure []) [t|IsInterface $iT|] instanceDecs

    when (length interface.requests > 0) do
      tellQs $ messageTypeDecs rTypeName requestContexts

    when (length interface.events > 0) do
      tellQs $ messageTypeDecs eTypeName eventContexts

  pure (public, internals)

  where
    iName = interfaceN interface
    iT = interfaceT interface
    instanceDecs = [
      tySynInstD (tySynEqn Nothing (appT (conT ''Request) iT) rT),
      tySynInstD (tySynEqn Nothing (appT (conT ''Event) iT) eT),
      tySynInstD (tySynEqn Nothing (appT (conT ''InterfaceName) iT) (litT (strTyLit interface.name))),
      valD (varP 'interfaceName) (normalB (stringE interface.name)) []
      ]
    rT :: Q Type
    rT = if length interface.requests > 0 then conT rTypeName else [t|Void|]
    rTypeName :: Name
    rTypeName = mkName $ "R_" <> interface.name
    rConName :: RequestSpec -> Name
    rConName (RequestSpec request) = mkName $ "R_" <> interface.name <> "_" <> request.name
    eT :: Q Type
    eT = if length interface.events > 0 then conT eTypeName else [t|Void|]
    eTypeName :: Name
    eTypeName = mkName $ "E_" <> interface.name
    eConName :: EventSpec -> Name
    eConName (EventSpec event) = mkName $ "E_" <> interface.name <> "_" <> event.name
    requestContext :: RequestSpec -> MessageContext
    requestContext req@(RequestSpec msgSpec) = MessageContext {
      msgInterfaceT = iT,
      msgT = rT,
      msgConName = rConName req,
      msgInterfaceSpec = interface,
      msgSpec = msgSpec
    }
    requestContexts = requestContext <$> interface.requests
    eventContext :: EventSpec -> MessageContext
    eventContext ev@(EventSpec msgSpec) = MessageContext {
      msgInterfaceT = iT,
      msgT = eT,
      msgConName = eConName ev,
      msgInterfaceSpec = interface,
      msgSpec = msgSpec
    }
    eventContexts = eventContext <$> interface.events

    aName :: Name
    aName = mkName "a"
    aType :: Q Type
    aType = varT aName
    mName :: Name
    mName = mkName "m"
    mType :: Q Type
    mType = varT mName

    requestClassD :: Q Dec
    requestClassD =
      -- [t|MonadCatch $mType|]
      classD (cxt []) (requestClassN interface) [plainTV mName, plainTV aName] [] (callSigD <$> requestContexts)

    eventClassD :: Q Dec
    eventClassD =
      -- [t|MonadCatch $mType|]
      classD (cxt []) (eventClassN interface) [plainTV mName, plainTV aName] [] (callSigD <$> eventContexts)

    callSigD :: MessageContext -> Q Dec
    callSigD msg = sigD (mkName (interface.name <> "__" <> msg.msgSpec.name)) [t|$aType -> $(applyArgTypes [t|$mType ()|])|]
      where
        applyArgTypes :: Q Type -> Q Type
        applyArgTypes xt = foldr (\x y -> [t|$x -> $y|]) xt (argumentType <$> msg.msgSpec.arguments)

interfaceSideInstanceDs :: InterfaceSpec -> Q [Dec]
interfaceSideInstanceDs interface = execWriterT do
  tellQs [d|instance IsInterfaceSide 'Client $iT|]
  tellQs [d|instance IsInterfaceSide 'Server $iT|]
  --tellQs [d|instance forall m a. IsInterfaceHandler 'Client m $iT a where {handleMessage = undefined}|]
  --tellQs [d|instance forall m a. IsInterfaceHandler 'Server m $iT a where {handleMessage = undefined}|]
  where
    iT = interfaceT interface


interfaceN :: InterfaceSpec -> Name
interfaceN interface = mkName $ "I_" <> interface.name

interfaceT :: InterfaceSpec -> Q Type
interfaceT interface = conT (interfaceN interface)

requestClassN :: InterfaceSpec -> Name
requestClassN interface = mkName $ "Requests_" <> interface.name

requestClassT :: InterfaceSpec -> Q Type
requestClassT interface = conT (requestClassN interface)

eventClassN :: InterfaceSpec -> Name
eventClassN interface = mkName $ "Events_" <> interface.name

eventClassT :: InterfaceSpec -> Q Type
eventClassT interface = conT (eventClassN interface)



data MessageContext = MessageContext {
  msgInterfaceT :: Q Type,
  msgT :: Q Type,
  msgConName :: Name,
  msgInterfaceSpec :: InterfaceSpec,
  msgSpec :: MessageSpec
}

-- | Pattern to match a message. Arguments can then be accessed by using 'msgArgE'.
msgConP :: MessageContext -> Q Pat
msgConP msg = conP msg.msgConName (varP . msgArgTempName <$> msg.msgSpec.arguments)

-- | Expression for accessing a message argument which has been matched from a request/event using 'msgArgConP'.
msgArgE :: MessageContext -> ArgumentSpec -> Q Exp
msgArgE _msg arg = varE (msgArgTempName arg)

-- | Helper for 'msgConP' and 'msgArgE'.
msgArgTempName :: ArgumentSpec -> Name
-- Add an "_" to prevent name conflicts with everything
msgArgTempName arg = mkName $ arg.name <> "_"


messageTypeDecs :: Name -> [MessageContext] -> Q [Dec]
messageTypeDecs name msgs = execWriterT do
  tellQ $ messageTypeD
  tellQ $ isMessageInstanceD t msgs
  tellQ $ showInstanceD
  where
    t :: Q Type
    t = conT name
    messageTypeD :: Q Dec
    messageTypeD = dataD (pure []) name [] Nothing (con <$> msgs) [derivingEq]
    con :: MessageContext -> Q Con
    con msg = normalC (msg.msgConName) (conField <$> msg.msgSpec.arguments)
      where
        conField :: ArgumentSpec -> Q BangType
        conField arg = defaultBangType (argumentType arg)
    showInstanceD :: Q Dec
    showInstanceD = instanceD (pure []) [t|Show $t|] [showD]
    showD :: Q Dec
    showD = funD 'show (showClause <$> msgs)
    showClause :: MessageContext -> Q Clause
    showClause msg =
      clause
        [msgConP msg]
        (normalB [|mconcat $(listE ([stringE (msg.msgSpec.name ++ "(")] <> mconcat (intersperse [stringE ", "] (showArgE <$> msg.msgSpec.arguments) <> [[stringE ")"]])))|])
        []
      where
        showArgE :: ArgumentSpec -> [Q Exp]
        showArgE arg = [stringE (arg.name ++ "="), [|showArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]]

isMessageInstanceD :: Q Type -> [MessageContext] -> Q Dec
isMessageInstanceD t msgs = instanceD (pure []) [t|IsMessage $t|] [opcodeNameD, getMessageD, putMessageD]
  where
    opcodeNameD :: Q Dec
    opcodeNameD = funD 'opcodeName ((opcodeNameClause <$> msgs) <> [opcodeNameInvalidClause])
    opcodeNameClause :: MessageContext -> Q Clause
    opcodeNameClause msg = clause [litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB ([|Just $(stringE msg.msgSpec.name)|])) []
    opcodeNameInvalidClause :: Q Clause
    opcodeNameInvalidClause = clause [wildP] (normalB ([|Nothing|])) []
    getMessageD :: Q Dec
    getMessageD = funD 'getMessage ((getMessageClause <$> msgs) <> [getMessageInvalidOpcodeClause])
    getMessageClause :: MessageContext -> Q Clause
    getMessageClause msg = clause [wildP, litP (integerL (fromIntegral msg.msgSpec.opcode))] (normalB getMessageE) []
      where
        getMessageE :: Q Exp
        getMessageE = applyA (conE (msg.msgConName)) ((\argT -> [|getArgument @($argT)|]) . argumentSpecType <$> msg.msgSpec.arguments)
    getMessageInvalidOpcodeClause :: Q Clause
    getMessageInvalidOpcodeClause = do
      let object = mkName "object"
      let opcode = mkName "opcode"
      clause [varP object, varP opcode] (normalB [|invalidOpcode $(varE object) $(varE opcode)|]) []
    putMessageD :: Q Dec
    putMessageD = funD 'putMessage (putMessageClauseD <$> msgs)
    putMessageClauseD :: MessageContext -> Q Clause
    putMessageClauseD msg = clause [msgConP msg] (normalB (putMessageE msg.msgSpec.arguments)) []
      where
        putMessageE :: [ArgumentSpec] -> Q Exp
        putMessageE [] = opcodeE
        putMessageE args = doE (((\arg -> noBindS [|putArgument @($(argumentSpecType arg)) $(msgArgE msg arg)|]) <$> args) <> [noBindS opcodeE])
        opcodeE :: Q Exp
        opcodeE = [|pure $(litE $ integerL $ fromIntegral msg.msgSpec.opcode)|]


derivingEq :: Q DerivClause
derivingEq = derivClause (Just StockStrategy) [[t|Eq|]]

derivingShow :: Q DerivClause
derivingShow = derivClause (Just StockStrategy) [[t|Show|]]

derivingInterfaceClient :: Q DerivClause
derivingInterfaceClient = derivClause (Just AnyclassStrategy) [[t|IsInterfaceSide 'Client|]]

derivingInterfaceServer :: Q DerivClause
derivingInterfaceServer = derivClause (Just AnyclassStrategy) [[t|IsInterfaceSide 'Server|]]

argumentType :: ArgumentSpec -> Q Type
argumentType argSpec = [t|Argument $(promoteArgumentSpecType argSpec.argType)|]

argumentSpecType :: ArgumentSpec -> Q Type
argumentSpecType argSpec = promoteArgumentSpecType argSpec.argType

promoteArgumentSpecType :: ArgumentType -> Q Type
promoteArgumentSpecType (ObjectArgument iName) = [t|ObjectId $(litT $ strTyLit iName)|]
promoteArgumentSpecType (NewIdArgument iName) = [t|NewId $(litT $ strTyLit iName)|]
promoteArgumentSpecType arg = do
  argExp <- (TH.lift arg)
  matchCon argExp
  where
    matchCon :: Exp -> Q Type
    matchCon (ConE name) = pure $ ConT name
    matchCon (AppE x _) = matchCon x
    matchCon _ = fail "Can only promote ConE expression"

defaultBangType :: Q Type -> Q BangType
defaultBangType = bangType (bang noSourceUnpackedness noSourceStrictness)


-- | (a -> b -> c -> d) -> [m a, m b, m c] -> m d
applyA :: Q Exp -> [Q Exp] -> Q Exp
applyA con [] = [|pure $con|]
applyA con (monadicE:monadicEs) = foldl (\x y -> [|$x <*> $y|]) [|$con <$> $monadicE|] monadicEs

-- | (a -> b -> c -> m d) -> [m a, m b, m c] -> m d
applyM :: Q Exp -> [Q Exp] -> Q Exp
applyM con [] = con
applyM con args = [|join $(applyA con args)|]


-- * XML parser

parseProtocol :: MonadFail m => BS.ByteString -> m ProtocolSpec
parseProtocol xml = do
  (Just element) <- pure $ parseXMLDoc xml
  interfaces <- mapM parseInterface $ findChildren (blank_name { qName = "interface" }) element
  pure ProtocolSpec {
    interfaces
  }

parseInterface :: MonadFail m => Element -> m InterfaceSpec
parseInterface element = do
  name <- getAttr "name" element
  version <- read <$> getAttr "version" element
  requests <- mapM (parseRequest name) $ zip [0..] $ findChildren (qname "request") element
  events <- mapM (parseEvent name) $ zip [0..] $ findChildren (qname "event") element
  pure InterfaceSpec {
    name,
    version,
    requests,
    events
  }

parseRequest :: MonadFail m => String -> (Opcode, Element) -> m RequestSpec
parseRequest x y = RequestSpec <$> parseMessage x y

parseEvent :: MonadFail m => String -> (Opcode, Element) -> m EventSpec
parseEvent x y = EventSpec <$> parseMessage x y

parseMessage :: MonadFail m => String -> (Opcode, Element) -> m MessageSpec
parseMessage interfaceName (opcode, element) = do
  name <- getAttr "name" element
  since <- read <<$>> peekAttr "since" element
  arguments <- mapM parseArgument $ zip [0..] $ findChildren (qname "arg") element
  forM_ arguments \arg -> do
    when
      do arg.argType == GenericNewIdArgument && (interfaceName /= "wl_registry" || name /= "bind")
      do fail $ "Invalid 'new_id' argument without 'interface' attribute encountered on " <> interfaceName <> "." <> name <> " (only valid on wl_registry.bind)"
    when
      do arg.argType == GenericObjectArgument && (interfaceName /= "wl_display" || name /= "error")
      do fail $ "Invalid 'object' argument without 'interface' attribute encountered on " <> interfaceName <> "." <> name <> " (only valid on wl_display.error)"
  pure MessageSpec  {
    name,
    since,
    opcode,
    arguments
  }


parseArgument :: forall m. MonadFail m => (Integer, Element) -> m ArgumentSpec
parseArgument (index, element) = do
  name <- getAttr "name" element
  argTypeStr <- getAttr "type" element
  interface <- peekAttr "interface" element
  argType <- parseArgumentType argTypeStr interface
  pure ArgumentSpec {
    name,
    index,
    argType
  }
  where
    parseArgumentType :: String -> Maybe String -> m ArgumentType
    parseArgumentType "int" Nothing = pure IntArgument
    parseArgumentType "uint" Nothing = pure UIntArgument
    parseArgumentType "fixed" Nothing = pure FixedArgument
    parseArgumentType "string" Nothing = pure StringArgument
    parseArgumentType "array" Nothing = pure ArrayArgument
    parseArgumentType "object" (Just interface) = pure (ObjectArgument interface)
    parseArgumentType "object" Nothing = pure GenericObjectArgument
    parseArgumentType "new_id" (Just interface) = pure (NewIdArgument interface)
    parseArgumentType "new_id" Nothing = pure GenericNewIdArgument
    parseArgumentType "fd" Nothing = pure FdArgument
    parseArgumentType x Nothing = fail $ "Unknown argument type \"" <> x <> "\" encountered"
    parseArgumentType x _ = fail $ "Argument type \"" <> x <> "\" should not have \"interface\" attribute"


qname :: String -> QName
qname name = blank_name { qName = name }

getAttr :: MonadFail m => String -> Element -> m String
getAttr name element = do
  (Just value) <- pure $ findAttr (qname name) element
  pure value

peekAttr :: Applicative m => String -> Element -> m (Maybe String)
peekAttr name element = pure $ findAttr (qname name) element