diff --git a/main.go b/main.go index e4ee0c9d0fd3a83ff284ad7a42bce4ced6ad7156..df38df80c10d900a09cc6617e5fade131b8e53c8 100644 --- a/main.go +++ b/main.go @@ -11,13 +11,13 @@ import ( kitlog "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" - "modernc.org/ql" _ "modernc.org/ql/driver" "git.c3pb.de/gbe/invinoveritas/auth" "git.c3pb.de/gbe/invinoveritas/log" "git.c3pb.de/gbe/invinoveritas/session" "git.c3pb.de/gbe/invinoveritas/storage" + "git.c3pb.de/gbe/invinoveritas/storage/query" "git.c3pb.de/gbe/invinoveritas/vino" ) @@ -114,23 +114,15 @@ func main() { level.Info(logger). Log("commit", commitHash, "build", buildTime) - db, err := storage.Open(*dbPath, logger) - if err != nil { - logger.Log("error", err, "msg", "can't open DB") - os.Exit(1) - } - defer db.Close() - ctx, done := context.WithCancel(context.Background()) defer done() - err = storage.InitDB(ctx, db) + db, err := storage.Open(ctx, *dbPath, logger) if err != nil { - level.Error(logger). - Log("wal_name", ql.WalName(*dbPath+".ql"), - "error", err, "msg", "can't initalize DB") + logger.Log("error", err, "msg", "can't open DB") os.Exit(1) } + defer db.Close() logged404 := log.Request(http.HandlerFunc(http.NotFound), kitlog.With(logger, "code", 404)) @@ -141,6 +133,7 @@ func main() { sessions := session.Provider{ DB: db, + Q: query.New(db), } handler := Handler{ diff --git a/session/authprovider.go b/session/authprovider.go index fcc680b397b7a56933d22526a086b3182be0586f..37512b8fdbd00d08db6b27d3157bae55a7cda835 100644 --- a/session/authprovider.go +++ b/session/authprovider.go @@ -20,20 +20,19 @@ import ( "git.c3pb.de/gbe/invinoveritas/auth" "git.c3pb.de/gbe/invinoveritas/log" + "git.c3pb.de/gbe/invinoveritas/storage/query" ) -func hashPassword(ctx context.Context, db *sqlx.DB, user, pass string) (string, error) { +func hashPassword(ctx context.Context, q *query.Queries, user, pass string) (string, error) { // Look up password salt from DB and hash password with it - var salt []byte - - err := db.GetContext(ctx, &salt, `SELECT val FROM state WHERE key = "pwsalt"`) + salt, err := q.GetSalt(ctx) if err != nil { return "", fmt.Errorf("loading password salt: %w", err) } h := sha256.New() - _, err = h.Write(salt) + _, err = h.Write([]byte(salt)) if err != nil { return "", err } @@ -51,36 +50,39 @@ type Info struct { type Provider struct { DB *sqlx.DB + Q *query.Queries } // Valid looks up the provided session token in a's database and returns the appropriate user if the token is // valid. If it is not, an error is returned. func (a Provider) Valid(ctx context.Context, token string) (*auth.User, error) { - var user auth.User + var ( + user auth.User + err error + ) - err := a.DB.GetContext(ctx, &user.Name, ` - SELECT users.name - FROM users, sessions - WHERE sessions.user == id(users) && sessions.token = ?1`, token) - if errors.Is(err, sql.ErrNoRows) { - // Let's see if there are any users at all. If not, we let 'em in. - var count int - err = a.DB.GetContext(ctx, &count, `SELECT count(*) FROM users`) - if err != nil { - return nil, err - } + err = a.Q.RunTx(ctx, func(q *query.Queries) error { + user.Name, err = q.IsValidSession(ctx, token) + if errors.Is(err, sql.ErrNoRows) { + // Let's see if there are any users at all. If not, we let 'em in. + count, err := q.CountUsers(ctx) + if err != nil { + return err + } - // No users in DB, allow everyone - if count == 0 { - return &user, nil - } + // No users in DB, allow everyone + if count == 0 { + return nil + } - return nil, auth.ErrAuthFailed - } + return auth.ErrAuthFailed + } + return nil + }) if err != nil { level.Error(log.GetContext(ctx)). - Log("error", err, "token", token, "msg", "got error during user lookup") + Log("error", err, "token", token, "msg", "got error during session lookup") return nil, err } @@ -123,99 +125,87 @@ func (a Provider) Handler(templateFS fs.FS) http.Handler { level.Debug(log). Log("pass", password, "debug", "attempting auth") - // Get salted password from DB - var userData struct { - ID int `db:"id"` - Salted string `db:"password"` - } + var token string - err := a.DB.GetContext(r.Context(), &userData, - `SELECT id() AS id, password FROM users WHERE name = $1`, - userName) + err := a.Q.RunTx(r.Context(), func(q *query.Queries) error { + userData, err := q.GetAuthData(r.Context(), userName) + if errors.Is(err, sql.ErrNoRows) { + // Either no such user or no users at all + count, err := q.CountUsers(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return fmt.Errorf("getting user count: %w", err) + } - if errors.Is(err, sql.ErrNoRows) { - // Either no such user or no users at all - var count int - err = a.DB.GetContext(r.Context(), &count, `SELECT count(*) FROM users`) - if err != nil { - level.Error(log). - Log("error", err, "msg", "can't get user count") - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + level.Debug(log). + Log("num_users", count, "msg", "got user count") + + // TODO: Deal with this. Report an error if != 0 users in db, otherwise allow them through. - level.Debug(log). - Log("num_users", count, "msg", "got user count") + w.WriteHeader(http.StatusUnauthorized) - // TODO: Deal with this. Report an error if != 0 users in db, otherwise allow them through. + rd := responseData{ + Error: "authentication failed", + } - w.WriteHeader(http.StatusUnauthorized) + err = tpl.ExecuteTemplate(w, "auth.tpl", rd) + if err != nil { + return fmt.Errorf("executing auth template: %w", err) + } - rd := responseData{ - Error: "authentication failed", + return nil + } + if err != nil { + return fmt.Errorf("getting salted password: %w", err) } - err := tpl.ExecuteTemplate(w, "auth.tpl", rd) + hashed, err := hashPassword(r.Context(), q, userName, password) if err != nil { - level.Error(log). - Log("error", err, "msg", "can't execute auth template") + http.Error(w, err.Error(), http.StatusInternalServerError) + return fmt.Errorf("hashing password: %w", err) } - return - } + if hashed != userData.Password { + level.Error(log). + Log("msg", "password mismatch") - if err != nil { - level.Error(log). - Log("error", err, "msg", "can't get salted password") - } + w.WriteHeader(http.StatusUnauthorized) - hashed, err := hashPassword(r.Context(), a.DB, userName, password) - if err != nil { - level.Error(log). - Log("error", err, "msg", "can't hash password") + rd := responseData{ + Error: "authentication failed", + } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + err := tpl.ExecuteTemplate(w, "auth.tpl", rd) + if err != nil { + return fmt.Errorf("executing auth template: %w", err) + } + + return nil + } - if hashed != userData.Salted { - level.Error(log). - Log("msg", "password mismatch") + log = kitlog.With(log, "user_id", userData.UserID) - w.WriteHeader(http.StatusUnauthorized) + // All good + level.Info(log). + Log("msg", "auth looks good") - rd := responseData{ - Error: "authentication failed", + addr := r.Header.Get("X-Forwarded-For") + if addr == "" { + addr = r.RemoteAddr } - err := tpl.ExecuteTemplate(w, "auth.tpl", rd) + token, err = a.createSession(r.Context(), q, int(userData.UserID), addr) if err != nil { - level.Error(log). - Log("error", err, "msg", "can't execute auth template") + http.Error(w, err.Error(), http.StatusInternalServerError) + return fmt.Errorf("creating session token: %w", err) } - return - } - - log = kitlog.With(log, "user_id", userData.ID) - - // All good - level.Info(log). - Log("msg", "auth looks good") - - addr := r.Header.Get("X-Forwarded-For") - if addr == "" { - addr = r.RemoteAddr - } - - token, err := a.createSession(r.Context(), userData.ID, addr) + return nil + }) if err != nil { - level.Error(log). - Log("error", err, - "msg", "can't create session token") - - http.Error(w, err.Error(), http.StatusInternalServerError) + level.Error(log).Log("msg", "authentication failed", "error", err) + // Response already written to client. return } @@ -236,7 +226,7 @@ func (a Provider) Handler(templateFS fs.FS) http.Handler { }) } -func (a Provider) createSession(ctx context.Context, userID int, remoteAddr string) (string, error) { +func (a Provider) createSession(ctx context.Context, q *query.Queries, userID int, remoteAddr string) (string, error) { var rawToken [40]byte n, err := crand.Read(rawToken[:]) @@ -328,70 +318,67 @@ func (a Provider) ListUsers(ctx context.Context) ([]string, error) { return names, err } -func (a Provider) UpdatePassword(ctx context.Context, userName, passOld, passNew string) (err error) { - hashedOld, err := hashPassword(ctx, a.DB, userName, passOld) - if err != nil { - return err - } - - hashedNew, err := hashPassword(ctx, a.DB, userName, passNew) - if err != nil { - return err - } - - tx, err := a.DB.Beginx() - if err != nil { - return err - } - defer func() { +func (a Provider) UpdatePassword(ctx context.Context, userName, passOld, passNew string) error { + err := a.Q.RunTx(ctx, func(q *query.Queries) error { + hashedOld, err := hashPassword(ctx, q, userName, passOld) if err != nil { - rerr := tx.Rollback() - if rerr != nil { - level.Error(log.GetContext(ctx)). - Log("error", rerr, "msg", "can't roll back transaction") - } - - return + return err } - err = tx.Commit() - }() - - var pwHash []byte - err = tx.GetContext(ctx, &pwHash, `SELECT password FROM users WHERE name = ?1`, userName) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return err - } + hashedNew, err := hashPassword(ctx, q, userName, passNew) + if err != nil { + return err + } - if len(pwHash) != 0 && string(pwHash) != hashedOld { - // A password exists already and it's not the one the user gave as their 'old' password. - return auth.ErrAuthFailed - } + authData, err := q.GetAuthData(ctx, userName) + if err != nil { + return err + } - level.Info(log.GetContext(ctx)). - Log("name", userName, "msg", "updating password") + if len(authData.Password) != 0 && authData.Password != hashedOld { + // A password exists already and it's not the one the user gave as their 'old' password. + return auth.ErrAuthFailed + } - // Either user doesn't exist yet (no entry) or the password they gave as their old matched. We - // can update. - // Try an update first, if that fails, an insert - res, err := tx.ExecContext(ctx, `UPDATE users SET password = ?1 WHERE name = ?2`, hashedNew, userName) - if err != nil { - return fmt.Errorf("updating password: %w", err) - } + level.Info(log.GetContext(ctx)). + Log("name", userName, "msg", "updating password") + + // Either user doesn't exist yet (no entry) or the password they gave as their old matched. We + // can update. + // Try an update first, if that fails, an insert + res, err := q.UpdateUser(ctx, query.UpdateUserParams{ + Name: userName, + Password: hashedNew, + }) - n, err := res.RowsAffected() - if err != nil { - return err - } + n, err := res.RowsAffected() + if err != nil { + return err + } + if n > 1 { + return fmt.Errorf("more than one user updated: %d", n) + } + if n > 0 { + return nil + } - if n == 0 { level.Info(log.GetContext(ctx)). Log("name", userName, "msg", "creating new user") - _, err = tx.ExecContext(ctx, `INSERT INTO users (name, password) VALUES (?1, ?2)`, userName, hashedNew) + err = q.InsertUser(ctx, query.InsertUserParams{ + Name: userName, + Password: hashedNew, + }) if err != nil { return err } + + return nil + }) + if err != nil { + level.Error(log.GetContext(ctx)). + Log("msg", "user update failed", "error", err) + return err } return nil diff --git a/session/authprovider_test.go b/session/authprovider_test.go index f2f11fb4949667b625ad844b4e8f5704ec726cc4..210aba4602dd34c92737670810c63e52504207f7 100644 --- a/session/authprovider_test.go +++ b/session/authprovider_test.go @@ -4,8 +4,9 @@ import ( "context" "testing" - "github.com/jmoiron/sqlx" - _ "modernc.org/ql/driver" + "git.c3pb.de/gbe/invinoveritas/storage" + "git.c3pb.de/gbe/invinoveritas/storage/query" + _ "modernc.org/sqlite" ) func assertNoError(t *testing.T, err error) { @@ -14,34 +15,38 @@ func assertNoError(t *testing.T, err error) { } } -func initDB(ctx context.Context, t *testing.T, db *sqlx.DB) { - tx, err := db.Begin() - assertNoError(t, err) - - _, err = tx.ExecContext(ctx, `CREATE TABLE state (key string, val string)`) - assertNoError(t, err) - - _, err = tx.ExecContext(ctx, `INSERT INTO state (key, val) VALUES (?1, ?2)`, "pwsalt", "1234") - assertNoError(t, err) +type testLog struct { + t *testing.T +} - err = tx.Commit() - assertNoError(t, err) +func (t testLog) Log(p ...interface{}) error { + t.t.Log(p...) + return nil } func TestHashPassword(t *testing.T) { ctx := context.Background() - db, err := sqlx.Open("ql-mem", "unit-test") - assertNoError(t, err) + db, err := storage.Open(ctx, "unit-test.sqlite", testLog{t: t}) + if err != nil { + t.Fatal("unexpected error:", err) + } defer db.Close() - initDB(ctx, t, db) + q := query.New(db) - h1, err := hashPassword(ctx, db, "test-user", "a nice password") + err = q.InsertSalt(ctx, "unit-test") assertNoError(t, err) - h2, err := hashPassword(ctx, db, "another-user", "a nice password") - assertNoError(t, err) + h1, err := hashPassword(ctx, q, "test-user", "a nice password") + if err != nil { + t.Fatal("unexpected error:", err) + } + + h2, err := hashPassword(ctx, q, "another-user", "a nice password") + if err != nil { + t.Fatal("unexpected error:", err) + } if h1 == h2 { t.Error("did not expect equal hashes") diff --git a/storage/db_test.go b/storage/db_test.go index fde682232f9e67d6ce875451485679c0c65b9ed8..64a9d711073355e4d96b645561b22399234572b2 100644 --- a/storage/db_test.go +++ b/storage/db_test.go @@ -1,11 +1,9 @@ package storage import ( - "context" "testing" - "github.com/jmoiron/sqlx" - _ "modernc.org/ql/driver" + _ "modernc.org/sqlite" ) func assertNoError(t *testing.T, err error) { @@ -25,150 +23,3 @@ func expectError(t *testing.T, err error) { t.Error("expected an error") } } - -func TestInitSQLdb(t *testing.T) { - ctx := context.Background() - - db, err := sqlx.Open("ql-mem", "unit-test") - assertNoError(t, err) - defer db.Close() - - // Try to init the DB twice. No call should error, and we should have a migrations - // table afterwards. - - err = InitDB(ctx, db) - assertNoError(t, err) - - err = InitDB(ctx, db) - assertNoError(t, err) - - // Assert that there is one migration - var count int - err = db.Get(&count, `SELECT count(*) FROM migrations`) - assertNoError(t, err) - - if count != 3 { - t.Errorf("unexpected number of applied migrations: %d, want 3", count) - } - - // Try to insert the same migration twice, ensure that the second insert fails - tx, err := db.Beginx() - assertNoError(t, err) - - _, err = tx.ExecContext(ctx, `INSERT INTO migrations VALUES (?1)`, "test") - assertNoError(t, err) - - err = tx.Commit() - assertNoError(t, err) - - tx, err = db.Beginx() - assertNoError(t, err) - - _, err = tx.ExecContext(ctx, `INSERT INTO migrations VALUES (?1)`, "test") - if err == nil { - t.Error("expected an error, got nil") - } - - err = tx.Rollback() - assertNoError(t, err) -} - -func TestInitialDBStructure(t *testing.T) { - ctx := context.Background() - - db, err := sqlx.Open("ql-mem", "unit-test") - assertNoError(t, err) - defer db.Close() - - // Try to init the DB twice. No call should error, and we should have a migrations - // table afterwards. - - err = InitDB(ctx, db) - assertNoError(t, err) - - // Run a few tests on the DB structure: - t.Run("wines-country", func(t *testing.T) { - // This tests the "country" column on wines - tx, err := db.Begin() - assertNoError(t, err) - defer tx.Rollback() //nolint:errcheck - - // - 2 characters - _, err = tx.Exec(`INSERT INTO wines (name, country) VALUES (?1, ?2)`, "test-wine", "DE") - expectNoError(t, err) - - // - more than 2 characters - _, err = tx.Exec(`INSERT INTO wines (name, country) VALUES (?1, ?2)`, "test-wine", "foobar") - expectError(t, err) - - // - less than 2 characters - _, err = tx.Exec(`INSERT INTO wines (name, country) VALUES (?1, ?2)`, "test-wine", "a") - expectError(t, err) - - // - null - _, err = tx.Exec(`INSERT INTO wines (name) VALUES (?1)`, "test-wine") - expectNoError(t, err) - }) - - t.Run("wine-comment", func(t *testing.T) { - // This test ensures that to insert a comment, there has to be a matching wine entry - tx, err := db.Begin() - assertNoError(t, err) - defer tx.Rollback() //nolint:errcheck - - res, err := tx.ExecContext(ctx, `INSERT INTO wines (name) VALUES (?1)`, "test-wine") - assertNoError(t, err) - - id, err := res.LastInsertId() - assertNoError(t, err) - - if id == 0 { - t.Fatal("unexpected insert id, want anything but 0") - } - - // - insert a comment with no existing wine - _, err = tx.ExecContext(ctx, `INSERT INTO comments (content, wine) VALUES(?1, ?2)`, "test!", id+1) - if err == nil { - t.Error("expected error when adding comment without wine") - } - - // - insert comment for the test wine - _, err = tx.ExecContext(ctx, `INSERT INTO comments (content, wine) VALUES(?1, ?2)`, "test!", id) - assertNoError(t, err) - - // - delete test wine - _, err = tx.ExecContext(ctx, `DELETE FROM wines`) - assertNoError(t, err) // Actually, I want this to error since there's a comment that references this wine. - - // - delete test comment - _, err = tx.ExecContext(ctx, `DELETE FROM comments`) - assertNoError(t, err) - }) - - t.Run("state", func(t *testing.T) { - tx, err := db.Begin() - assertNoError(t, err) - defer tx.Rollback() //nolint:errcheck - - // Insert the same key twice - _, err = tx.Exec(`INSERT INTO state (key, val) VALUES (?1, ?2)`, "test", "fnord") - expectNoError(t, err) - - _, err = tx.Exec(`INSERT INTO state (key, val) VALUES (?1, ?2)`, "test", "fnord") - expectError(t, err) - }) - - t.Run("users", func(t *testing.T) { - // This test ensures that to insert a comment, there has to be a matching wine entry - tx, err := db.Begin() - assertNoError(t, err) - defer tx.Rollback() //nolint:errcheck - - // Insert the same user twice - _, err = tx.Exec(`INSERT INTO users (name, password) VALUES (?1, ?2)`, "test", "fnord") - expectNoError(t, err) - - _, err = tx.Exec(`INSERT INTO users (name, password) VALUES (?1, ?2)`, "test", "fnord") - expectError(t, err) - }) -} diff --git a/storage/migrations/0001-initial.sql b/storage/migrations/0001-initial.sql index a045e1b133a56c34875aebe83f11d474d8274eb9..d6cc7422ea8cd858fde9cfd6f8e085488ded025e 100644 --- a/storage/migrations/0001-initial.sql +++ b/storage/migrations/0001-initial.sql @@ -1,25 +1,42 @@ CREATE TABLE wines ( - name string, - rating int, -- number of stars - picture blob, -- jpeg/png image of the label on the bottle - country string (len(country) == 2 || country IS NULL) -- ISO2 country code + name text, + rating int, -- number of stars + picture blob, -- jpeg/png image of the label on the bottle + country text -- ISO2 country code ); CREATE TABLE comments ( - content string, - wine int (wine IN (SELECT id(wines) FROM wines)) + content text, + wine int, + + foreign key (wine) references wines (rowid) ); CREATE TABLE state ( - key string, - val string -); + key text not null, + val text not null, -CREATE UNIQUE INDEX state_key ON state (key); + unique(key) +); CREATE TABLE users ( - name string, - password string -- salted and hashed + userID integer primary key not null, + name text not null, + password text not null, -- salted and hashed + + unique(name) +); + +CREATE TABLE sessions ( + userID int not null, -- User this session is valid for + token text not null, -- Text of session token, to be stored in user cookiejar + created time not null, -- Creation time of the session + remote string not null, -- remote address that created the session + + foreign key (userID) references users (userID) ); -CREATE UNIQUE INDEX users_name ON users (name); \ No newline at end of file +CREATE TABLE IF NOT EXISTS migrations ( + name text not null, + unique(text) +); \ No newline at end of file diff --git a/storage/migrations/0002-session.sql b/storage/migrations/0002-session.sql deleted file mode 100644 index 2217340560a2ba420e98e115a11e68d249bac0ee..0000000000000000000000000000000000000000 --- a/storage/migrations/0002-session.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE TABLE sessions ( - user int (user IN (SELECT id(users) FROM users)), -- User this session is valid for - token string, -- Text of session token, to be stored in user cookiejar - created time, -- Creation time of the session -); \ No newline at end of file diff --git a/storage/migrations/0003-session-remote.sql b/storage/migrations/0003-session-remote.sql deleted file mode 100644 index e4b6c3245390af2b342f89984db7d370cf46fbbc..0000000000000000000000000000000000000000 --- a/storage/migrations/0003-session-remote.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add a "remote IP address" column to the session table -ALTER TABLE sessions ADD remote string; \ No newline at end of file diff --git a/storage/query/concat-schema.sh b/storage/query/concat-schema.sh new file mode 100755 index 0000000000000000000000000000000000000000..abcfadab180d5822d20dee3da7eba5524676b68b --- /dev/null +++ b/storage/query/concat-schema.sh @@ -0,0 +1,7 @@ +#!/bin/sh +set -e + +find ../migrations -type f -name '*.sql' | sort | while read f; do + echo "-- $f" + cat "$f" +done > $1 diff --git a/storage/query/db.go b/storage/query/db.go new file mode 100644 index 0000000000000000000000000000000000000000..890a04a7aaa1d6d1ecd04758705c97bc3fdf43c3 --- /dev/null +++ b/storage/query/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package query + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/storage/query/generate.go b/storage/query/generate.go new file mode 100644 index 0000000000000000000000000000000000000000..acea161afb3b0b4eee1d254302bbcf9864e62413 --- /dev/null +++ b/storage/query/generate.go @@ -0,0 +1,3 @@ +//go:generate ./concat-schema.sh schema.sql +//go:generate sqlc generate +package query diff --git a/storage/query/models.go b/storage/query/models.go new file mode 100644 index 0000000000000000000000000000000000000000..6d3458774044011f1c1f1660c88f6a90547c5f4b --- /dev/null +++ b/storage/query/models.go @@ -0,0 +1,41 @@ +// Code generated by sqlc. DO NOT EDIT. + +package query + +import ( + "database/sql" +) + +type Comment struct { + Content sql.NullString + Wine sql.NullInt32 +} + +type Migration struct { + Name string +} + +type Session struct { + UserID int32 + Token string + Created string + Remote string +} + +type State struct { + Key string + Val string +} + +type User struct { + UserID int32 + Name string + Password string +} + +type Wine struct { + Name sql.NullString + Rating sql.NullInt32 + Picture []byte + Country sql.NullString +} diff --git a/storage/query/query.sql b/storage/query/query.sql new file mode 100644 index 0000000000000000000000000000000000000000..2f04b63f8cfc9fc6eb25ae5e4308eb2cfefbba2a --- /dev/null +++ b/storage/query/query.sql @@ -0,0 +1,33 @@ +-- name: AddMigration :exec +insert into migrations (name) values (@name); + +-- name: CountMigrations :one +select count(*) from migrations where name = @name; + +---- +---- Session and user management +---- + +-- name: IsValidSession :one +select users.name +from users +join sessions using (userID) +where sessions.token = @token; + +-- name: CountUsers :one +select count(*) from users; + +-- name: GetAuthData :one +select userID, password from users where name = @name; + +-- name: GetSalt :one +select val from state where key = 'pwsalt'; + +-- name: InsertSalt :exec +insert into state (key, val) values ('pwsalt', @salt); + +-- name: UpdateUser :execresult +update users set password = @password where name = @name; + +-- name: InsertUser :exec +insert into users (name, password) values (@name, @password); diff --git a/storage/query/query.sql.go b/storage/query/query.sql.go new file mode 100644 index 0000000000000000000000000000000000000000..f9cd16a9f69cebd54fa45b414a8096432d880e35 --- /dev/null +++ b/storage/query/query.sql.go @@ -0,0 +1,121 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package query + +import ( + "context" + "database/sql" +) + +const addMigration = `-- name: AddMigration :exec +insert into migrations (name) values ($1) +` + +func (q *Queries) AddMigration(ctx context.Context, name string) error { + _, err := q.db.ExecContext(ctx, addMigration, name) + return err +} + +const countMigrations = `-- name: CountMigrations :one +select count(*) from migrations where name = $1 +` + +func (q *Queries) CountMigrations(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, countMigrations, name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countUsers = `-- name: CountUsers :one +select count(*) from users +` + +func (q *Queries) CountUsers(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countUsers) + var count int64 + err := row.Scan(&count) + return count, err +} + +const getAuthData = `-- name: GetAuthData :one +select userID, password from users where name = $1 +` + +type GetAuthDataRow struct { + UserID int32 + Password string +} + +func (q *Queries) GetAuthData(ctx context.Context, name string) (GetAuthDataRow, error) { + row := q.db.QueryRowContext(ctx, getAuthData, name) + var i GetAuthDataRow + err := row.Scan(&i.UserID, &i.Password) + return i, err +} + +const getSalt = `-- name: GetSalt :one +select val from state where key = 'pwsalt' +` + +func (q *Queries) GetSalt(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getSalt) + var val string + err := row.Scan(&val) + return val, err +} + +const insertSalt = `-- name: InsertSalt :exec +insert into state (key, val) values ('pwsalt', $1) +` + +func (q *Queries) InsertSalt(ctx context.Context, salt string) error { + _, err := q.db.ExecContext(ctx, insertSalt, salt) + return err +} + +const insertUser = `-- name: InsertUser :exec +insert into users (name, password) values ($1, $2) +` + +type InsertUserParams struct { + Name string + Password string +} + +func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) error { + _, err := q.db.ExecContext(ctx, insertUser, arg.Name, arg.Password) + return err +} + +const isValidSession = `-- name: IsValidSession :one + +select users.name +from users +join sessions using (userID) +where sessions.token = $1 +` + +//-- +//-- Session and user management +//-- +func (q *Queries) IsValidSession(ctx context.Context, token string) (string, error) { + row := q.db.QueryRowContext(ctx, isValidSession, token) + var name string + err := row.Scan(&name) + return name, err +} + +const updateUser = `-- name: UpdateUser :execresult +update users set password = $1 where name = $2 +` + +type UpdateUserParams struct { + Password string + Name string +} + +func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (sql.Result, error) { + return q.db.ExecContext(ctx, updateUser, arg.Password, arg.Name) +} diff --git a/storage/query/schema.sql b/storage/query/schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..b6376d4bff4d3b60241cdf3fd27c455539ee3c8b --- /dev/null +++ b/storage/query/schema.sql @@ -0,0 +1,43 @@ +-- ../migrations/0001-initial.sql +CREATE TABLE wines ( + name text, + rating int, -- number of stars + picture blob, -- jpeg/png image of the label on the bottle + country text -- ISO2 country code +); + +CREATE TABLE comments ( + content text, + wine int, + + foreign key (wine) references wines (rowid) +); + +CREATE TABLE state ( + key text not null, + val text not null, + + unique(key) +); + +CREATE TABLE users ( + userID integer primary key not null, + name text not null, + password text not null, -- salted and hashed + + unique(name) +); + +CREATE TABLE sessions ( + userID int not null, -- User this session is valid for + token text not null, -- Text of session token, to be stored in user cookiejar + created time not null, -- Creation time of the session + remote string not null, -- remote address that created the session + + foreign key (userID) references users (userID) +); + +CREATE TABLE IF NOT EXISTS migrations ( + name text not null, + unique(text) +); \ No newline at end of file diff --git a/storage/query/sqlc.yaml b/storage/query/sqlc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..782b93bdc5f33ac0656a7db3530ea6e438ff62e3 --- /dev/null +++ b/storage/query/sqlc.yaml @@ -0,0 +1,14 @@ +version: 1 +packages: + - path: "." + name: "query" + # Hack, because there's no SQlite driver, but Postgres is mostly close enough (tm) + engine: "postgresql" + schema: "schema.sql" + queries: "query.sql" +overrides: + - go_type: "string" + db_type: "pg_catalog.time" +rename: + rowid: "RowID" + userid: "UserID" \ No newline at end of file diff --git a/storage/query/wraptx.go b/storage/query/wraptx.go new file mode 100644 index 0000000000000000000000000000000000000000..de0e32d6ed14cd4abb5dfa1f447b0e4bdda4e345 --- /dev/null +++ b/storage/query/wraptx.go @@ -0,0 +1,43 @@ +package query + +import ( + "context" + "database/sql" +) + +type TXFunc func(q *Queries) error + +type BeginTxer interface { + BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) +} + +func (q *Queries) Raw() DBTX { + return q.db +} + +func (q *Queries) RunTx(ctx context.Context, txf TXFunc) (err error) { + db := q.db.(BeginTxer) + + var tx *sql.Tx + + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } + + defer func() { + if err != nil { + tx.Rollback() + return + } + + err = tx.Commit() + }() + + err = txf(q.WithTx(tx)) + if err != nil { + return err + } + + return nil +} diff --git a/storage/storage.go b/storage/storage.go index e317e1bdc2a4dcdf2b00f72d45fa9f9c9d10da5d..4159b45df4a9da05184c00843d23899522ed9d37 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -10,17 +10,16 @@ import ( "image" "image/png" "io/fs" - "os" "sort" "git.c3pb.de/gbe/invinoveritas/log" + "git.c3pb.de/gbe/invinoveritas/storage/query" "git.c3pb.de/gbe/invinoveritas/vino" "github.com/Masterminds/squirrel" kitlog "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/jmoiron/sqlx" "golang.org/x/image/draw" - "modernc.org/ql" ) var errNotFound = errors.New("not found") @@ -308,121 +307,82 @@ func (b Backend) Store(ctx context.Context, v *vino.Vino) (err error) { return nil } -func Open(path string, logger kitlog.Logger) (*sqlx.DB, error) { - var retried bool - - path += ".ql" +//go:embed migrations/*.sql +var migrationFS embed.FS -retry: - db, err := sqlx.Open("ql2", path) +func Open(ctx context.Context, path string, logger kitlog.Logger) (*sqlx.DB, error) { + db, err := sqlx.Open("sqlite", path) if err != nil { return nil, err } - level.Info(logger).Log("msg", "running test-select on db") - _, err = db.Exec(`SELECT count(*) FROM __Table`) - if err != nil && !retried { - level.Error(logger). - Log("error", err, - "msg", "got error") - - retried = true - - // WAL may be corrupted. Manually remove it and re-try open. See - // https://gitlab.com/cznic/ql/-/issues/227 for more info. - err := os.Remove(ql.WalName(path)) - if err != nil { - return nil, err - } - - goto retry + pragmas := []string{ + "PRAGMA busy_timeout = 600000", + "PRAGMA journal_mode = WAL", } - if err != nil { - return nil, err - } - - return db, nil -} - -//go:embed migrations/*.sql -var migrationFS embed.FS - -// TODO: merge into open? -func InitDB(ctx context.Context, db *sqlx.DB) error { - // Make sure that we have a table for the migrations - - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return fmt.Errorf("creating tx: %w", err) - } - - defer func() { + for _, p := range pragmas { + _, err = db.ExecContext(ctx, p) if err != nil { - rerr := tx.Rollback() - if rerr != nil { - level.Error(log.GetContext(ctx)). - Log("error", rerr, - "msg", "can't roll back transaction") - } - - return + return nil, fmt.Errorf("running %q: %w", p, err) } - - err = tx.Commit() - }() - - _, err = tx.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS migrations (name string)`) - if err != nil { - return fmt.Errorf("creating migrations table: %w", err) } - _, err = tx.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS migrations_name ON migrations (name)`) + // Make sure that we have a table for the migrations, the rest runs in a transaction. + _, err = db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS migrations (name text, unique(name))`) if err != nil { - return fmt.Errorf("creating migrations index: %w", err) + return nil, fmt.Errorf("creating migrations table: %w", err) } + q := query.New(db) + entries, err := migrationFS.ReadDir("migrations") if err != nil { - return fmt.Errorf("reading embedded migration FS: %w", err) + return nil, fmt.Errorf("reading embedded migration FS: %w", err) } sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() }) - for _, e := range entries { - // Check if we need to apply that migration - var count int - err = tx.GetContext(ctx, &count, `SELECT count(*) FROM migrations WHERE name = ?1`, e.Name()) - if err != nil { - return fmt.Errorf("checking for migration %q: %w", e.Name(), err) - } + err = q.RunTx(ctx, func(q *query.Queries) error { + for _, e := range entries { + // Check if we need to apply that migration + count, err := q.CountMigrations(ctx, e.Name()) + if err != nil { + return fmt.Errorf("checking for migration %q: %w", e.Name(), err) + } - if count > 1 { - return fmt.Errorf("migration %q applied more than once", e.Name()) - } + if count > 1 { + return fmt.Errorf("migration %q applied more than once", e.Name()) + } - if count == 1 { - continue - } + if count == 1 { + continue + } - content, err := fs.ReadFile(migrationFS, "migrations/"+e.Name()) - if err != nil { - return fmt.Errorf("reading migration %q: %w", e.Name(), err) - } + content, err := fs.ReadFile(migrationFS, "migrations/"+e.Name()) + if err != nil { + return fmt.Errorf("reading migration %q: %w", e.Name(), err) + } - _, err = tx.ExecContext(ctx, string(content)) - if err != nil { - return fmt.Errorf("applying migration %q: %w", e.Name(), err) - } + _, err = q.Raw().ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("applying migration %q: %w", e.Name(), err) + } - // Mark migration as applied - _, err = tx.ExecContext(ctx, `INSERT INTO migrations (name) VALUES (?1)`, e.Name()) - if err != nil { - return fmt.Errorf("marking migration %q as applied: %w", e.Name(), err) + // Mark migration as applied + err = q.AddMigration(ctx, e.Name()) + if err != nil { + return fmt.Errorf("marking migration %q as applied: %w", e.Name(), err) + } } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("applying migrations: %w", err) } - return nil + return db, nil }