{-# LANGUAGE ViewPatterns #-}

module Quasar.Observable (
  -- * Observable core types
  IsRetrievable(..),
  retrieveIO,
  IsObservable(..),
  Observable(..),
  ObservableMessage(..),

  -- * ObservableVar
  ObservableVar,
  newObservableVar,
  setObservableVar,
  withObservableVar,
  modifyObservableVar,
  modifyObservableVar_,

  -- * Helper functions
  observeWhile,
  observeWhile_,
  fnObservable,
  synchronousFnObservable,
  mergeObservable,
  joinObservable,
  bindObservable,
  unsafeObservableIO,

  -- * Helper types
  ObservableCallback,
) where

import Control.Applicative
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (SomeAsyncException)
import Control.Monad.Catch
import Control.Monad.Except
import Control.Monad.Trans.Maybe
import Data.HashMap.Strict qualified as HM
import Data.Unique
import Quasar.Awaitable
import Quasar.Core
import Quasar.Prelude


-- | A single notification delivered to the callbacks of an observable.
data ObservableMessage a
  = ObservableUpdate a
  -- ^ A new current value.
  | ObservableLoading
  -- ^ No value is available yet (e.g. it is still being computed).
  | ObservableNotAvailable SomeException
  -- ^ No value is available because producing it failed with the given exception.
  deriving stock Show
  deriving stock Generic

-- Only the value carried by `ObservableUpdate` is mapped; the other constructors pass through unchanged.
instance Functor ObservableMessage where
  fmap fn = \case
    ObservableUpdate value -> ObservableUpdate (fn value)
    ObservableLoading -> ObservableLoading
    ObservableNotAvailable ex -> ObservableNotAvailable ex

-- Combining two messages: a failure on either side wins (the left side is checked first). Otherwise a pending value
-- on either side yields `ObservableLoading`. Only two updates produce a combined update.
instance Applicative ObservableMessage where
  pure = ObservableUpdate
  liftA2 fn lhs rhs = case (lhs, rhs) of
    (ObservableNotAvailable ex, _) -> ObservableNotAvailable ex
    (_, ObservableNotAvailable ex) -> ObservableNotAvailable ex
    (ObservableUpdate x, ObservableUpdate y) -> ObservableUpdate (fn x y)
    _ -> ObservableLoading


-- | Types from which a value of type @v@ can be retrieved.
--
-- The functional dependency @a -> v@ makes the value type determined by the retrievable type.
class IsRetrievable v a | a -> v where
  -- | Start retrieving the value in a `MonadAsync` context; the value is delivered through the returned `Task`.
  retrieve :: MonadAsync m => a -> m (Task v)

-- | Retrieve a value in plain `IO`: run `retrieve` with the default async manager, then `awaitIO` the resulting
-- `Task`.
retrieveIO :: IsRetrievable v a => a -> IO v
retrieveIO x = withDefaultAsyncManager (retrieve x) >>= awaitIO

-- | Types whose value of type @v@ can change over time and be observed.
--
-- Every observable can also be retrieved once (superclass `IsRetrievable`).
class IsRetrievable v o => IsObservable v o | o -> v where
  -- | Subscribe a callback that receives `ObservableMessage`s. The returned `Disposable` is used to end the
  -- subscription.
  observe :: o -> (ObservableMessage v -> IO ()) -> IO Disposable

  -- | Wrap into the existential `Observable` type. The default applies the `Observable` constructor directly.
  toObservable :: o -> Observable v
  toObservable = Observable

  -- | Apply a function to every value. The default wraps the observable in a `MappedObservable`; instances may
  -- override it to fuse the mapping into their own representation.
  mapObservable :: (v -> a) -> o -> Observable a
  mapObservable f = Observable . MappedObservable f

-- | Observe until the callback returns `False`. The callback will also be unsubscribed when the `AsyncManager` is disposed.
--
-- NOTE: currently not implemented. The body is `undefined`, so running the returned action throws. The
-- commented-out code below is the previous implementation, kept as a reference for the rewrite.
observeWhile :: (IsObservable v o, MonadAsync m) => o -> (ObservableMessage v -> IO Bool) -> m Disposable
observeWhile observable callback = do
  --disposeVar <- liftIO $ newTVarIO False

  --innerDisposable <- liftIO $ observe observable \msg -> do
  --  disposeRequested <- readTVarIO disposeVar
  --  unless disposeRequested do
  --    continue <- callback msg
  --    unless continue $ atomically $ writeTVar disposeVar True

  ---- Bind the disposable to the ResourceManager, to prevent leaks if the `async` is disposed
  --disposable <- boundDisposable $ dispose innerDisposable

  --task <- async do
  --  liftIO $ atomically do
  --    disposeRequested <- readTVar disposeVar
  --    unless disposeRequested retry
  --  liftIO $ dispose disposable

  --pure (disposable <> (toDisposable task))

  undefined -- TODO reimplement after ResourceManager API is changed


-- | Like `observeWhile`, but discards the returned `Disposable`.
-- Observe until the callback returns `False`. The callback will also be unsubscribed when the `ResourceManager` is disposed.
observeWhile_ :: (IsObservable v o, MonadAsync m) => o -> (ObservableMessage v -> IO Bool) -> m ()
observeWhile_ observable callback =
  -- Dropping the handle relies on `observeWhile` attaching the subscription to the surrounding manager.
  -- NOTE(review): `observeWhile` is currently a placeholder, so this is the intended contract, not a verified one.
  () <$ observeWhile observable callback


-- | A subscriber callback: invoked with each message an observable emits.
type ObservableCallback v = ObservableMessage v -> IO ()


-- | Existential quantification wrapper for the IsObservable type class.
--
-- Every class method is forwarded to the wrapped value. This includes `mapObservable`, so the wrapped
-- implementation can fuse mappings.
data Observable v = forall o. IsObservable v o => Observable o
instance IsRetrievable v (Observable v) where
  retrieve (Observable inner) = retrieve inner
instance IsObservable v (Observable v) where
  observe (Observable inner) = observe inner
  -- Already wrapped; don't add another `Observable` layer.
  toObservable = id
  mapObservable fn (Observable inner) = mapObservable fn inner

-- `fmap` delegates to `mapObservable`, so the underlying observable can fuse consecutive mappings.
instance Functor Observable where
  fmap = mapObservable

-- `pure` yields a constant observable. `liftA2` subscribes to both inputs and recombines them on every update.
instance Applicative Observable where
  pure value = toObservable (ConstObservable value)
  liftA2 = mergeObservable

-- Bind is implemented by `BindObservable`. On every outer update it re-subscribes to the continuation's result.
instance Monad Observable where
  (>>=) outer continuation = toObservable (BindObservable outer continuation)

-- Throwing yields a `FailedObservable`, which reports `ObservableNotAvailable` with the exception to each subscriber.
instance MonadThrow Observable where
  throwM :: forall e v. Exception e => e -> Observable v
  throwM ex = toObservable (FailedObservable @v (toException ex))

-- Catching is implemented by `CatchObservable`, which switches to the handler's observable on a matching failure.
instance MonadCatch Observable where
  catch source handler = toObservable (CatchObservable source handler)

-- `fail` throws the message as an `IOError` (via `userError`).
instance MonadFail Observable where
  fail message = throwM (userError message)

-- `empty` is a failed observable. `<|>` falls back to the right-hand observable when the left one fails with any
-- exception.
instance Alternative Observable where
  empty = fail "empty"
  (<|>) primary fallback = catchAll primary (\_ -> fallback)

-- | Uses the default `mzero`/`mplus`, i.e. the `Alternative` instance.
instance MonadPlus Observable



-- | An observable that applies a function to every value of an upstream observable.
data MappedObservable b = forall a o. IsObservable a o => MappedObservable (a -> b) o
instance IsRetrievable v (MappedObservable v) where
  retrieve (MappedObservable fn upstream) = fn <<$>> retrieve upstream
instance IsObservable v (MappedObservable v) where
  observe (MappedObservable fn upstream) callback =
    observe upstream (\message -> callback (fmap fn message))
  -- Fuse consecutive mappings instead of nesting another `MappedObservable` layer.
  mapObservable outer (MappedObservable inner upstream) =
    Observable (MappedObservable (outer . inner) upstream)



-- | Implementation of `>>=` for `Observable`.
--
-- The outer observable is observed. Every `ObservableUpdate` from it is passed to the continuation, and the resulting
-- inner observable replaces the previous inner subscription. Only messages from the most recent inner subscription
-- are forwarded to the subscriber.
data BindObservable r = forall a. BindObservable (Observable a) (a -> Observable r)

instance IsRetrievable r (BindObservable r) where
  -- Retrieve the outer value first, then retrieve from the observable the continuation selects for it.
  retrieve (BindObservable fx fn) = async $ do
    x <- awaitResult $ retrieve fx
    awaitResult $ retrieve $ fn x

instance IsObservable r (BindObservable r) where
  observe :: BindObservable r -> (ObservableMessage r -> IO ()) -> IO Disposable
  observe (BindObservable fx fn) callback = do
    -- Create a resource manager to ensure all subscriptions are cleaned up when disposing.
    resourceManager <- newResourceManager

    -- Set to `True` once disposal starts; `outerCallback` then stops creating inner subscriptions.
    isDisposingVar <- newTVarIO False
    -- Disposable of the current inner subscription. It is empty while `outerCallback` is replacing it.
    disposableVar <- newTMVarIO noDisposable
    -- Key of the current inner subscription (`Nothing` until the first outer message). Inner callbacks take it while
    -- they run. This serializes them and makes `outerCallback` wait before switching to a new key.
    keyVar <- newTMVarIO Nothing

    leftDisposable <- observe fx (outerCallback resourceManager isDisposingVar disposableVar keyVar)

    attachDisposeAction_ resourceManager $ do
      atomically $ writeTVar isDisposingVar True
      d1 <- dispose leftDisposable
      -- Block while the `outerCallback` is running
      d2 <- dispose =<< atomically (takeTMVar disposableVar)
      pure (d1 <> d2)

    pure $ toDisposable resourceManager
    where
      -- Handles one message from the outer observable. It switches the current key, which invalidates callbacks of
      -- the previous inner subscription, hands the previous disposable to `disposeEventually`, and stores the new one.
      outerCallback resourceManager isDisposingVar disposableVar keyVar observableMessage = mask $ \unmask -> do
        key <- newUnique

        join $ atomically $ do
          readTVar isDisposingVar >>= \case
            False -> do
              -- Blocks while an inner callback is running
              void $ swapTMVar keyVar (Just key)

              oldDisposable <- takeTMVar disposableVar

              -- IO action that will run after the STM transaction
              pure $ do
                -- NOTE(review): if this throws, `disposableVar` stays empty and the dispose action's `takeTMVar`
                -- would block; confirm `disposeEventually` cannot throw here.
                disposeEventually resourceManager oldDisposable

                -- Only the subscription itself runs unmasked. If it throws, `disposableVar` is refilled so the
                -- dispose action does not block on an empty var.
                disposable <-
                  unmask (outerMessageHandler key observableMessage)
                    `onException`
                      atomically (putTMVar disposableVar noDisposable)

                atomically $ putTMVar disposableVar disposable

            -- When already disposing no new handlers should be registered
            True -> pure $ pure ()

        where
          -- An update subscribes to the continuation's observable. Loading and failures of the outer observable are
          -- forwarded directly and leave no inner subscription (`noDisposable`).
          outerMessageHandler key (ObservableUpdate x) = observe (fn x) (innerCallback key)
          outerMessageHandler _ ObservableLoading = noDisposable <$ callback ObservableLoading
          outerMessageHandler _ (ObservableNotAvailable ex) = noDisposable <$ callback (ObservableNotAvailable ex)

          -- Forwards a message from the inner subscription identified by `key`. Messages are dropped once a newer
          -- outer message has switched the key.
          innerCallback :: Unique -> ObservableMessage r -> IO ()
          innerCallback key x = do
            bracket
              -- Take key var to prevent parallel callbacks
              (atomically $ takeTMVar keyVar)
              -- Put key back
              (atomically . putTMVar keyVar)
              -- Call callback when key is still valid
              (\currentKey -> when (Just key == currentKey) $ callback x)


-- | Implementation of `catch` for `Observable`.
--
-- The source observable is observed. When it reports `ObservableNotAvailable` with an exception of type @e@, the
-- observable returned by the handler is subscribed to and its messages are forwarded instead. Any later message from
-- the source replaces that handler subscription again.
data CatchObservable e r = Exception e => CatchObservable (Observable r) (e -> Observable r)

instance IsRetrievable r (CatchObservable e r) where
  -- If retrieving from the source fails with an exception of type `e`, retrieve from the handler's observable instead.
  retrieve (CatchObservable fx fn) = async $
    awaitResult (retrieve fx) `catch` \ex -> awaitResult (retrieve (fn ex))

instance IsObservable r (CatchObservable e r) where
  observe :: CatchObservable e r -> (ObservableMessage r -> IO ()) -> IO Disposable
  observe (CatchObservable fx fn) callback = do
    -- Create a resource manager to ensure all subscriptions are cleaned up when disposing.
    resourceManager <- newResourceManager

    -- Set to `True` once disposal starts; `outerCallback` then stops creating handler subscriptions.
    isDisposingVar <- newTVarIO False
    -- Disposable of the current handler subscription. It is empty while `outerCallback` is replacing it.
    disposableVar <- newTMVarIO noDisposable
    -- Key of the current handler subscription (`Nothing` until the first source message). Inner callbacks take it
    -- while they run. This serializes them and makes `outerCallback` wait before switching to a new key.
    keyVar <- newTMVarIO Nothing

    leftDisposable <- observe fx (outerCallback resourceManager isDisposingVar disposableVar keyVar)

    attachDisposeAction_ resourceManager $ do
      atomically $ writeTVar isDisposingVar True
      d1 <- dispose leftDisposable
      -- Block while the `outerCallback` is running
      d2 <- dispose =<< atomically (takeTMVar disposableVar)
      pure (d1 <> d2)

    pure $ toDisposable resourceManager
    where
      -- Handles one message from the source observable. It switches the current key, which invalidates callbacks of
      -- the previous handler subscription, hands the previous disposable to `disposeEventually`, and stores the new one.
      outerCallback resourceManager isDisposingVar disposableVar keyVar observableMessage = mask $ \unmask -> do
        key <- newUnique

        join $ atomically $ do
          readTVar isDisposingVar >>= \case
            False -> do
              -- Blocks while an inner callback is running
              void $ swapTMVar keyVar (Just key)

              oldDisposable <- takeTMVar disposableVar

              -- IO action that will run after the STM transaction
              pure $ do
                -- NOTE(review): if this throws, `disposableVar` stays empty and the dispose action's `takeTMVar`
                -- would block; confirm `disposeEventually` cannot throw here.
                disposeEventually resourceManager oldDisposable

                -- Only the subscription itself runs unmasked. If it throws, `disposableVar` is refilled so the
                -- dispose action does not block on an empty var.
                disposable <-
                  unmask (outerMessageHandler key observableMessage)
                    `onException`
                      atomically (putTMVar disposableVar noDisposable)

                atomically $ putTMVar disposableVar disposable

            -- When already disposing no new handlers should be registered
            True -> pure $ pure ()

        where
          -- A failure whose exception is of type `e` subscribes to the handler's observable. Every other message
          -- (updates, loading, non-matching failures) is forwarded directly and leaves no handler subscription.
          outerMessageHandler key (ObservableNotAvailable (fromException -> Just ex)) = observe (fn ex) (innerCallback key)
          outerMessageHandler _ msg = noDisposable <$ callback msg

          -- Forwards a message from the handler subscription identified by `key`. Messages are dropped once a newer
          -- source message has switched the key.
          innerCallback :: Unique -> ObservableMessage r -> IO ()
          innerCallback key x = do
            bracket
              -- Take key var to prevent parallel callbacks
              (atomically $ takeTMVar keyVar)
              -- Put key back
              (atomically . putTMVar keyVar)
              -- Call callback when key is still valid
              (\currentKey -> when (Just key == currentKey) $ callback x)



-- | A mutable variable that notifies its subscribers whenever its value is replaced.
--
-- The `MVar` holds the current value together with the callbacks of all current subscribers, keyed by a `Unique`
-- per subscription.
newtype ObservableVar v = ObservableVar (MVar (v, HM.HashMap Unique (ObservableCallback v)))
instance IsRetrievable v (ObservableVar v) where
  retrieve (ObservableVar mvar) = liftIO (fmap (successfulTask . fst) (readMVar mvar))
instance IsObservable v (ObservableVar v) where
  observe (ObservableVar mvar) callback = do
    subscriptionKey <- newUnique
    -- The current value is delivered while the variable is locked. An update therefore cannot slip in between
    -- reading the value and registering the callback.
    modifyMVar_ mvar $ \(current, subscribers) -> do
      callback (ObservableUpdate current)
      pure (current, HM.insert subscriptionKey callback subscribers)
    synchronousDisposable (unsubscribe subscriptionKey)
    where
      unsubscribe :: Unique -> IO ()
      unsubscribe key = modifyMVar_ mvar $ \(current, subscribers) ->
        pure (current, HM.delete key subscribers)

-- | Create an `ObservableVar` holding @initialValue@ and no subscribers.
newObservableVar :: v -> IO (ObservableVar v)
newObservableVar initialValue = fmap ObservableVar (newMVar (initialValue, HM.empty))

-- | Replace the stored value and notify every current subscriber with the new value.
--
-- Subscribers are called while the variable is locked.
setObservableVar :: ObservableVar v -> v -> IO ()
setObservableVar var value = modifyObservableVar_ var (\_ -> pure value)


-- | Update the stored value with an `IO` action that also produces a result. Every subscriber is notified with the
-- new value while the variable is locked, and the result is returned.
--
-- If the action throws, the stored value is left unchanged and no subscriber is notified.
modifyObservableVar :: ObservableVar v -> (v -> IO (v, a)) -> IO a
modifyObservableVar (ObservableVar mvar) f = modifyMVar mvar step
  where
    step (oldState, subscribers) = do
      (newState, result) <- f oldState
      mapM_ (notify newState) subscribers
      pure ((newState, subscribers), result)
    notify newState callback = callback (ObservableUpdate newState)

-- | Update the stored value with an `IO` action and notify every subscriber with the new value while the variable
-- is locked.
--
-- If the action throws, the stored value is left unchanged and no subscriber is notified.
modifyObservableVar_ :: ObservableVar v -> (v -> IO v) -> IO ()
modifyObservableVar_ (ObservableVar mvar) f =
  modifyMVar_ mvar $ \(oldState, subscribers) -> do
    newState <- f oldState
    let message = ObservableUpdate newState
    mapM_ ($ message) subscribers
    pure (newState, subscribers)

-- | Run an action on the current value while holding the variable's lock. No update can happen until it returns.
withObservableVar :: ObservableVar v -> (v -> IO a) -> IO a
withObservableVar (ObservableVar mvar) action = withMVar mvar (\pair -> action (fst pair))



-- | Monadic bind for any `IsObservable`: both sides are converted with `toObservable` first.
bindObservable :: (IsObservable a ma, IsObservable b mb) => ma -> (a -> mb) -> Observable b
bindObservable fx fn = toObservable fx >>= toObservable . fn

-- | Flatten an observable of observables. The result follows whichever inner observable was emitted most recently.
joinObservable :: (IsObservable i o, IsObservable v i) => o -> Observable v
joinObservable outer = join (toObservable <$> toObservable outer)


-- | Merge two observables using a given merge function. Whenever one of the inputs is updated, the resulting
-- observable updates according to the merge function.
--
-- There is no caching involved, every subscriber effectively subscribes to both input observables.
-- The merged callback is only invoked once both inputs have delivered at least one message.
data MergedObservable r o0 v0 o1 v1 = MergedObservable (v0 -> v1 -> r) o0 o1
instance forall r o0 v0 o1 v1. (IsRetrievable v0 o0, IsRetrievable v1 o1) => IsRetrievable r (MergedObservable r o0 v0 o1 v1) where
  -- Combine the `Task`s returned by retrieving both inputs using the merge function.
  retrieve (MergedObservable merge obs0 obs1) = liftA2 (liftA2 merge) (retrieve obs0) (retrieve obs1)
instance forall r o0 v0 o1 v1. (IsObservable v0 o0, IsObservable v1 o1) => IsObservable r (MergedObservable r o0 v0 o1 v1) where
  observe (MergedObservable merge obs0 obs1) callback = do
    -- Latest message from each input; `Nothing` until that input has sent its first message.
    var0 <- newTVarIO Nothing
    var1 <- newTVarIO Nothing
    d0 <- observe obs0 (mergeCallback var0 var1 . writeTVar var0 . Just)
    -- If subscribing to the second input throws, dispose the first subscription instead of leaking it.
    d1 <- observe obs1 (mergeCallback var0 var1 . writeTVar var1 . Just) `onException` dispose d0
    pure $ mconcat [d0, d1]
    where
      -- Record one input's message. Once both inputs have a message, emit the combined message; failures and
      -- loading states are combined by the `Applicative` instance of `ObservableMessage`.
      mergeCallback :: TVar (Maybe (ObservableMessage v0)) -> TVar (Maybe (ObservableMessage v1)) -> STM () -> IO ()
      mergeCallback var0 var1 update = do
        mMerged <- atomically $ do
          update
          runMaybeT $ liftA2 (liftA2 merge) (MaybeT (readTVar var0)) (MaybeT (readTVar var1))

        -- Run the callback only once both values have been received
        -- NOTE(review): the callback runs after the transaction commits. Under concurrent updates of both inputs,
        -- combined messages may therefore be delivered out of order.
        mapM_ callback mMerged


-- | Merge two observables using a given merge function. Whenever one of the inputs is updated, the resulting
-- observable updates according to the merge function.
--
-- Behaves like `liftA2` on `Observable` but accepts anything that implements `IsObservable`.
--
-- There is no caching involved, every subscriber effectively subscribes to both input observables.
mergeObservable :: (IsObservable v0 o0, IsObservable v1 o1) => (v0 -> v1 -> r) -> o0 -> o1 -> Observable r
mergeObservable combine lhs rhs = toObservable (MergedObservable combine lhs rhs)

-- | An observable defined directly by a retrieve function and an observe function.
data FnObservable v = FnObservable {
  retrieveFn :: forall m. MonadAsync m => m (Task v),
  observeFn :: (ObservableMessage v -> IO ()) -> IO Disposable
}
instance IsRetrievable v (FnObservable v) where
  -- Kept eta-expanded: `retrieveFn` has a higher-rank field type.
  retrieve self = retrieveFn self
instance IsObservable v (FnObservable v) where
  observe self = observeFn self
  -- Fuse the mapping into both functions instead of wrapping the observable in a `MappedObservable`.
  mapObservable fn FnObservable{retrieveFn = upstreamRetrieve, observeFn = upstreamObserve} =
    Observable $ FnObservable {
      retrieveFn = fn <<$>> upstreamRetrieve,
      observeFn = \listener -> upstreamObserve (listener . fmap fn)
    }

-- | Implement an Observable by directly providing functions for `observe` and `retrieve`.
fnObservable
  :: ((ObservableMessage v -> IO ()) -> IO Disposable)
  -> (forall m. MonadAsync m => m (Task v))
  -> Observable v
fnObservable observeImpl retrieveImpl =
  toObservable FnObservable {retrieveFn = retrieveImpl, observeFn = observeImpl}

-- | Implement an Observable from an observe function and a synchronous `IO` action used for `retrieve`. The action's
-- result is wrapped with `successfulTask`.
synchronousFnObservable
  :: forall v. ((ObservableMessage v -> IO ()) -> IO Disposable)
  -> IO v
  -> Observable v
synchronousFnObservable observeImpl synchronousRetrieveFn =
  fnObservable observeImpl (liftIO (successfulTask <$> synchronousRetrieveFn))


-- | An observable whose value never changes.
newtype ConstObservable v = ConstObservable v
instance IsRetrievable v (ConstObservable v) where
  retrieve (ConstObservable value) = pure (pure value)
instance IsObservable v (ConstObservable v) where
  -- Deliver the value once; there is nothing to unsubscribe from afterwards.
  observe (ConstObservable value) callback = noDisposable <$ callback (ObservableUpdate value)


-- | An observable that is permanently unavailable because of the given exception.
newtype FailedObservable v = FailedObservable SomeException
instance IsRetrievable v (FailedObservable v) where
  -- NOTE(review): throws in the calling monad instead of returning a failed `Task`.
  retrieve (FailedObservable ex) = liftIO (throwIO ex)
instance IsObservable v (FailedObservable v) where
  observe (FailedObservable ex) callback = noDisposable <$ callback (ObservableNotAvailable ex)


-- | Create an observable by simply running an IO action whenever a value is requested or a callback is registered.
--
-- There is no mechanism to send more than one update, so the resulting `Observable` will only be correct in specific
-- situations.
--
-- When observed, the subscriber first receives `ObservableLoading`. It then receives either the action's result, or
-- `ObservableNotAvailable` if the action throws a synchronous exception. Asynchronous exceptions (e.g. from
-- `killThread` or `timeout`) are rethrown rather than reported, so they still interrupt the observing thread.
unsafeObservableIO :: forall v. IO v -> Observable v
unsafeObservableIO action = synchronousFnObservable observeFn action
  where
    observeFn :: (ObservableMessage v -> IO ()) -> IO Disposable
    observeFn callback = do
      callback ObservableLoading
      value <- (ObservableUpdate <$> action) `catchAll` reportFailure
      callback value
      pure noDisposable

    -- Turn a synchronous failure of `action` into a message; rethrow asynchronous exceptions unchanged.
    reportFailure :: SomeException -> IO (ObservableMessage v)
    reportFailure ex
      | Just _ <- fromException @SomeAsyncException ex = throwIO ex
      | otherwise = pure (ObservableNotAvailable ex)


-- TODO implement
--cacheObservable :: IsObservable v o => o -> Observable v
--cacheObservable = undefined