package main import ( "context" "crypto/rand" "embed" "fmt" "io/fs" "sort" "github.com/google/uuid" "github.com/jmoiron/sqlx" log "github.com/sirupsen/logrus" bolt "go.etcd.io/bbolt" ) const expectDBVersion = "1" // initBoltDB initializes db with persistent state like the password salt and checks whether an already initialized database has the expected version number. func initBoltDB(db *bolt.DB) error { // Initialize persistent state like password salt err := db.Update(func(tx *bolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists([]byte("meta")) if err != nil { return err } // Check database version storedVersion := bucket.Get([]byte("version")) if storedVersion != nil && string(storedVersion) != expectDBVersion { return fmt.Errorf("Unexpected database version: have %q, want %q", storedVersion, expectDBVersion) } // Make sure the db version is stored err = bucket.Put([]byte("version"), []byte(expectDBVersion)) if err != nil { return err } // Create new PW salt if it doesn't already exist pwSalt := bucket.Get([]byte("pwsalt")) if pwSalt == nil { pwSalt = make([]byte, 16) _, err := rand.Read(pwSalt) if err != nil { return err } err = bucket.Put([]byte("pwsalt"), pwSalt) if err != nil { return err } } return nil }) if err != nil { return err } return nil } //go:embed migrations/*.sql var migrationFS embed.FS func initSQLdb(ctx context.Context, db *sqlx.DB) error { // Make sure that we have a table for the migrations tx, err := db.BeginTxx(ctx, nil) if err != nil { return fmt.Errorf("creating tx: %w", err) } defer func() { if err != nil { tx.Rollback() return } tx.Commit() }() _, err = tx.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS migrations (name string)`) if err != nil { return fmt.Errorf("creating migrations table: %w", err) } _, err = tx.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS migrations_name ON migrations (name)`) if err != nil { return fmt.Errorf("creating migrations index: %w", err) } entries, err := migrationFS.ReadDir("migrations") if err != nil { return fmt.Errorf("reading embedded migration FS: %w", err) } sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) for _, e := range entries { // Check if we need to apply that migration var count int err = tx.GetContext(ctx, &count, `SELECT count(*) FROM migrations WHERE name = ?1`, e.Name()) if err != nil { return fmt.Errorf("checking for migration %q: %w", e.Name(), err) } if count > 1 { return fmt.Errorf("migration %q applied more than once", e.Name()) } if count == 1 { continue } content, err := fs.ReadFile(migrationFS, "migrations/"+e.Name()) if err != nil { return fmt.Errorf("reading migration %q: %w", e.Name(), err) } _, err = tx.ExecContext(ctx, string(content)) if err != nil { return fmt.Errorf("applying migration %q: %w", e.Name(), err) } // Mark migration as applied _, err = tx.ExecContext(ctx, `INSERT INTO migrations (name) VALUES (?1)`, e.Name()) if err != nil { return fmt.Errorf("marking migration %q as applied: %w", e.Name(), err) } } return nil } func migrateBoltToQL(ctx context.Context, b *bolt.DB, ql *sqlx.DB) error { // Check if the ql DB has already been the target of a migration var count int err := ql.GetContext(ctx, &count, `SELECT count(*) FROM state WHERE key = "migrated-from-bolt" AND val = "true"`) if err != nil { return err } if count == 1 { log.Println("migration already done") return nil // Nothing to do here } tx, err := ql.Beginx() defer func() { if err != nil { log.Println("rolling back transaction") tx.Rollback() return } tx.Commit() }() // Migrate users err = migrateUsers(ctx, b, tx) if err != nil { return fmt.Errorf("migrating users: %w", err) } // Migrate wines err = migrateWines(ctx, b, tx) if err != nil { return fmt.Errorf("migrating wines: %w", err) } // Migrate state err = migrateState(ctx, b, tx) if err != nil { return fmt.Errorf("migrating state: %w", err) } // Mark migration complete _, err = tx.ExecContext(ctx, `INSERT INTO state (key, val) VALUES (?1, ?2)`, "migrated-from-bolt", "true") if err != nil { return fmt.Errorf("marking migration complete: %w", err) } log.Println("migration complete") return nil } func migrateUsers(ctx context.Context, b *bolt.DB, sqlTx *sqlx.Tx) error { err := b.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte("users")) if bucket == nil || bucket.Stats().KeyN == 0 { return nil // No users, nothing to migrate } err := bucket.ForEach(func(k, v []byte) error { name := string(k) pwhash := string(v) log.WithField("name", name). Info("migrating user") _, err := sqlTx.ExecContext(ctx, `INSERT INTO users (name, password) VALUES (?1, ?2)`, name, pwhash) return err }) return err }) return err } func migrateWines(ctx context.Context, b *bolt.DB, sqlTx *sqlx.Tx) error { err := b.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte("wines")) if bucket == nil || bucket.Stats().KeyN == 0 { return nil } err := bucket.ForEach(func(k, d []byte) error { if d != nil { return fmt.Errorf("%q not a bucket", k) } u, err := uuid.ParseBytes(k) if err != nil { return err } v, err := loadVinoBolt(tx, u) if err != nil { return err } log.WithFields(log.Fields{ "uuid": u.String(), "name": v.Name, }).Info("migrating wine") data := bucket.Bucket(k) if data == nil { return fmt.Errorf("no data for %q", k) } rawPicture := data.Get([]byte("picture")) res, err := sqlTx.ExecContext(ctx, `INSERT INTO wines (name, rating, picture, country) VALUES (?1, ?2, ?3, ?4)`, v.Name, v.Rating, rawPicture, string(v.Country[:])) if err != nil { return err } id, err := res.LastInsertId() if err != nil { return err } // Insert comments for _, c := range v.Comments { _, err := sqlTx.ExecContext(ctx, `INSERT INTO comments (wine, content) VALUES (?1, ?2)`, id, c.Content) if err != nil { return err } } return nil }) return err }) return err } func migrateState(ctx context.Context, b *bolt.DB, sqlTx *sqlx.Tx) error { err := b.View(func(tx *bolt.Tx) error { for _, name := range []string{"state", "meta"} { bucket := tx.Bucket([]byte(name)) if bucket == nil || bucket.Stats().KeyN == 0 { continue } log.WithFields(log.Fields{ "name": name, "keys": bucket.Stats().KeyN, }).Info("migrating metadata bucket") err := bucket.ForEach(func(k, v []byte) error { key := string(k) val := string(v) _, err := sqlTx.ExecContext(ctx, `INSERT INTO state (key, val) VALUES (?1, ?2)`, key, val) return err }) if err != nil { return err } } return nil }) return err }