package main

import (
	"context"
	"embed"
	"fmt"
	"io/fs"
	"sort"

	"github.com/jmoiron/sqlx"
	"github.com/sirupsen/logrus"
)

// migrationFS holds the SQL migration files matching migrations/*.sql,
// embedded into the binary at build time. initDB reads them from the
// "migrations" directory and applies them in file-name order.
//
// NOTE: the //go:embed directive below must stay directly above the variable
// declaration (only blank lines and // comments may separate them).
//go:embed migrations/*.sql
var migrationFS embed.FS

// initDB brings the database up to date by applying every SQL migration
// embedded in migrationFS (directory "migrations") that has not been applied
// yet.
//
// Behavior:
//   - Migrations are applied in lexical file-name order.
//   - After each migration runs, its file name is recorded in the migrations
//     table. That table and its unique index on name are created on first use.
//   - All work happens in a single transaction. It is committed only when
//     every step succeeds; a commit failure is returned to the caller.
//   - On any error, or on a panic, the transaction is rolled back and the
//     panic is re-raised afterwards.
//
// It returns an error if any migration has been recorded more than once.
func initDB(ctx context.Context, db *sqlx.DB) (err error) {
	tx, err := db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("creating tx: %w", err)
	}

	// err is the named result, so by the time this runs it holds whatever the
	// function is returning. Commit only on a clean, non-panicking exit;
	// otherwise roll back. A commit failure is surfaced through err.
	defer func() {
		p := recover()
		if p == nil && err == nil {
			if cerr := tx.Commit(); cerr != nil {
				err = fmt.Errorf("committing tx: %w", cerr)
			}
			return
		}

		if rerr := tx.Rollback(); rerr != nil {
			logrus.WithError(rerr).
				Error("can't roll back transaction")
		}

		if p != nil {
			panic(p)
		}
	}()

	// Make sure that we have a table for the migrations.
	if _, err = tx.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS migrations (name string)`); err != nil {
		return fmt.Errorf("creating migrations table: %w", err)
	}

	if _, err = tx.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS migrations_name ON migrations (name)`); err != nil {
		return fmt.Errorf("creating migrations index: %w", err)
	}

	entries, err := migrationFS.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("reading embedded migration FS: %w", err)
	}

	// Apply migrations in a deterministic, name-based order.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].Name() < entries[j].Name()
	})

	for _, e := range entries {
		name := e.Name()

		// Check if we need to apply that migration.
		var count int
		err = tx.GetContext(ctx, &count, `SELECT count(*) FROM migrations WHERE name = ?1`, name)
		if err != nil {
			return fmt.Errorf("checking for migration %q: %w", name, err)
		}

		if count > 1 {
			return fmt.Errorf("migration %q applied more than once", name)
		}

		if count == 1 {
			continue
		}

		// Assign with = (not :=) so the loop body never shadows the result err.
		var content []byte
		content, err = fs.ReadFile(migrationFS, "migrations/"+name)
		if err != nil {
			return fmt.Errorf("reading migration %q: %w", name, err)
		}

		if _, err = tx.ExecContext(ctx, string(content)); err != nil {
			return fmt.Errorf("applying migration %q: %w", name, err)
		}

		// Mark migration as applied.
		if _, err = tx.ExecContext(ctx, `INSERT INTO migrations (name) VALUES (?1)`, name); err != nil {
			return fmt.Errorf("marking migration %q as applied: %w", name, err)
		}
	}

	return nil
}