package main

import (
	"context"
	"embed"
	"flag"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	kitlog "github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"github.com/jmoiron/sqlx"
	"modernc.org/ql"
	_ "modernc.org/ql/driver"

	"git.c3pb.de/gbe/invinoveritas/auth"
	"git.c3pb.de/gbe/invinoveritas/log"
	"git.c3pb.de/gbe/invinoveritas/session"
)

// templateFS holds the files matching templates/*.tpl, embedded at build
// time. It is handed to sessions.Handler when building the auth-failure
// handler in wrapMiddleware.
//go:embed templates/*.tpl
var templateFS embed.FS

// staticFS holds everything under static/, embedded at build time. main
// serves it via http.FileServer on the "/static/" route; because embedded
// paths keep their "static/" prefix, request paths map onto it without a
// StripPrefix.
//go:embed static/*
var staticFS embed.FS

// httpError reports a failed request. It logs the failure at error level
// through the request-scoped logger returned by log.Get(r), then answers the
// client with the given HTTP status via http.Error.
//
// When err is non-nil, its text is appended to msg (separated by ": ") and
// that combined text is used both in the log entry and in the response body.
func httpError(w http.ResponseWriter, r *http.Request, msg string, err error, status int) {
	text := msg
	if err != nil {
		text = text + ": " + err.Error()
	}

	errLogger := level.Error(log.Get(r))
	errLogger.Log(
		"status", status,
		"error", err,
		"msg", text,
	)

	http.Error(w, text, status)
}

// sessionProvider is the session-management functionality Handler depends
// on. In main it is satisfied by a session.Provider value.
type sessionProvider interface {
	// ListSessions returns the sessions known to the provider.
	ListSessions(ctx context.Context) ([]session.Info, error)

	// DeleteSession removes the session identified by token (presumably
	// revoking it — confirm against the implementation).
	DeleteSession(ctx context.Context, token string) error
}

// userProvider is the user-management functionality Handler depends on. In
// main it is satisfied by the same session.Provider value as sessionProvider.
type userProvider interface {
	// ListUsers returns the names of the known users.
	ListUsers(ctx context.Context) ([]string, error)

	// CreateUser adds a user called name. The meaning of the returned string
	// is implementation-defined (NOTE(review): presumably an initial
	// password — confirm).
	CreateUser(ctx context.Context, name string) (string, error)

	// DeleteUser removes the user called name.
	DeleteUser(ctx context.Context, name string) error

	// UpdatePassword changes name's password from oldPW to newPW.
	UpdatePassword(ctx context.Context, name, oldPW, newPW string) error
}

// Handler bundles the dependencies used by the HTTP handlers registered in
// main (index, details, user and img are defined outside this view). In main
// all three fields are backed by the same database: sqlDB is the handle from
// openDB, and sp and up are both the session.Provider built on that handle.
type Handler struct {
	sqlDB *sqlx.DB        // direct database access
	sp    sessionProvider // listing and deleting sessions
	up    userProvider    // listing/creating/deleting users, changing passwords
}

// addCacheHeaders decorates next so that each response carries
// "Cache-Control: public, max-age=86400, immutable": publicly cacheable for
// one day and never revalidated during that time. The value is appended with
// Header().Add before next runs, so next can still add or replace values.
func addCacheHeaders(next http.Handler) http.Handler {
	const cacheControl = "public, max-age=86400, immutable"

	wrapped := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Add("Cache-Control", cacheControl)
		next.ServeHTTP(w, r)
	}

	return http.HandlerFunc(wrapped)
}

// openDB opens the ql database file path+".ql" through the "ql2" driver. It
// checks that the database is usable by running a test query against the
// __Table system table.
//
// If the first test query fails, the write-ahead log may be corrupted (see
// https://gitlab.com/cznic/ql/-/issues/227). In that case the failed handle
// is closed, the WAL file is removed and opening is retried exactly once. A
// handle is returned only if its test query succeeded. Every other handle
// opened here is closed before openDB returns.
//
// If the probe fails and there is no WAL file to remove, the probe error is
// returned, because a retry could not change the outcome.
func openDB(path string, logger kitlog.Logger) (*sqlx.DB, error) {
	path += ".ql"

	for attempt := 1; ; attempt++ {
		db, err := sqlx.Open("ql2", path)
		if err != nil {
			return nil, err
		}

		level.Info(logger).Log("msg", "running test-select on db")
		_, err = db.Exec(`SELECT count(*) FROM __Table`)
		if err == nil {
			return db, nil
		}

		// The handle is unusable. Release it before touching its files or
		// giving up, so it neither leaks nor holds the WAL open while the WAL
		// is removed. The close error is deliberately ignored: the probe
		// error is the one worth reporting.
		_ = db.Close()

		if attempt > 1 {
			return nil, err
		}

		level.Error(logger).
			Log("error", err,
				"msg", "got error")

		// WAL may be corrupted. Manually remove it and re-try open.
		if rmErr := os.Remove(ql.WalName(path)); rmErr != nil {
			if os.IsNotExist(rmErr) {
				// No WAL to blame, so a retry cannot help. Report the probe
				// failure rather than the missing file.
				return nil, err
			}
			return nil, rmErr
		}
	}
}

// wrapMiddleware wraps common middleware around hdlr:
//   - log.Request to make a logger available in the request context
//   - addPrivateCacheHeaders, so that responses are never stored by shared
//     caches and are revalidated by the browser before reuse
//   - auth.Require, given authFailed (the session handler rendered from
//     templateFS) as the handler for requests that fail authentication
//
// The long-lived "public, immutable" policy of addCacheHeaders is deliberately
// not used here. These responses depend on the session, and the login page
// shares URLs with the protected pages. Caching them as immutable for a day
// makes browsers keep showing stale or login content after the data or the
// session changes, and lets shared caches store per-user pages.
func wrapMiddleware(hdlr http.Handler, sessions session.Provider, logger kitlog.Logger) http.Handler {
	authFailed := log.Request(addPrivateCacheHeaders(sessions.Handler(templateFS)), logger)
	return log.Request(addPrivateCacheHeaders(auth.Require(hdlr, authFailed, sessions)), logger)
}

// addPrivateCacheHeaders sets "Cache-Control: private, no-cache" on every
// response of next. It uses Set rather than Add so that only this value is
// present when next starts; next may still replace it.
func addPrivateCacheHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Cache-Control", "private, no-cache")
		next.ServeHTTP(w, r)
	})
}

// main parses the command-line flags, configures the logger and hands over
// to runMain. The process exit status is whatever runMain returns.
func main() {
	dbPath := flag.String("db", "vino", "Path to database file")
	listenAddr := flag.String("listen", "127.0.0.1:7878", "Listening address")
	debug := flag.Bool("debug", false, "Enable debug logging")

	flag.Parse()

	logger := kitlog.NewLogfmtLogger(kitlog.NewSyncWriter(os.Stdout))
	logger = kitlog.With(logger, "ts", kitlog.DefaultTimestampUTC)

	filter := level.AllowInfo()
	if *debug {
		filter = level.AllowDebug()
	}

	logger = level.NewFilter(logger, filter)

	os.Exit(runMain(*dbPath, *listenAddr, logger))
}

// runMain does the following:
//   - opens the database at dbPath (openDB appends ".ql") and initializes it
//   - registers the HTTP routes on http.DefaultServeMux
//   - serves them on listenAddr until the server fails or SIGINT/SIGTERM
//     arrives
//
// On a signal it shuts the server down gracefully, waiting at most 15
// seconds.
//
// It returns the process exit code: 0 after a clean shutdown, 1 otherwise.
// It is separate from main so that its deferred cleanup, in particular
// closing the database, runs before os.Exit terminates the process.
func runMain(dbPath, listenAddr string, logger kitlog.Logger) int {
	db, err := openDB(dbPath, logger)
	if err != nil {
		level.Error(logger).Log("error", err, "msg", "can't open DB")
		return 1
	}
	defer db.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	err = initDB(ctx, db)
	if err != nil {
		level.Error(logger).Log("wal_name", ql.WalName(dbPath+".ql"),
			"error", err,
			"msg", "can't initalize DB")
		return 1
	}

	http.HandleFunc("/favicon.ico", http.NotFound)

	http.Handle("/static/", log.Request(addCacheHeaders(http.FileServer(http.FS(staticFS))), logger))

	sessions := session.Provider{
		DB: db,
	}

	handler := Handler{
		sqlDB: db,
		sp:    sessions,
		up:    sessions,
	}

	http.Handle("/details/img/", wrapMiddleware(handler.img(), sessions, logger))
	http.Handle("/details/", wrapMiddleware(handler.details(), sessions, logger))
	http.Handle("/user/", wrapMiddleware(handler.user(), sessions, logger))
	http.Handle("/", wrapMiddleware(handler.index(), sessions, logger))

	// Handler is left nil so the server uses http.DefaultServeMux, where the
	// routes above were registered. ReadHeaderTimeout stops clients from
	// holding connections open by trickling request headers.
	srv := &http.Server{
		Addr:              listenAddr,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Subscribe before starting the server so an early signal is not lost.
	// SIGTERM is included because it is what service managers send; without
	// it the process would be killed without closing the database.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	level.Info(logger).
		Log("addr", "http://"+listenAddr, "msg", "starting http server")

	// Buffered so the goroutine can always deliver its result and exit. This
	// includes the http.ErrServerClosed it returns after Shutdown, which
	// nobody receives.
	errs := make(chan error, 1)

	go func() {
		// srv (not the package-level ListenAndServe) must be the running
		// server, otherwise srv.Shutdown below would have nothing to stop.
		errs <- srv.ListenAndServe()
	}()

	// Wait for OS signal or error, shut down server if signal received
	select {
	case sig := <-sigChan:
		level.Info(logger).
			Log("signal", sig, "msg", "received signal, terminating")

		shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 15*time.Second)
		defer cancelShutdown()

		if err := srv.Shutdown(shutdownCtx); err != nil {
			level.Error(logger).
				Log("error", err, "msg", "clean shut down failed")
			return 1
		}
		return 0
	case err := <-errs:
		level.Error(logger).
			Log("error", err, "msg", "HTTP server failed to start")
		return 1
	}
}