Skip to content
Snippets Groups Projects
Commit 7ea5648a authored by gbe's avatar gbe
Browse files

Get started with linting and such

parent f300b296
No related branches found
No related tags found
No related merge requests found
linters:
enable-all: true
disable:
- testpackage
issues:
exclude-rules:
- linters:
- errcheck
source: "Log\\("
- linters:
- gomnd
source: "\\*time.(Second|Minute|Hour)"
\ No newline at end of file
......@@ -53,6 +53,7 @@ func Require(next http.Handler, authFailed http.Handler, provider Provider) http
Log("msg", "denying POST access for unknown user")
authFailed.ServeHTTP(w, r)
return
}
......
......@@ -24,7 +24,7 @@ import (
func (h Handler) img() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
if r.Method != http.MethodGet {
httpError(w, r, "bad method", errors.New(r.Method), http.StatusMethodNotAllowed)
return
}
......@@ -55,6 +55,13 @@ func (h Handler) img() http.HandlerFunc {
}
}
const (
actionDelete = "delete"
actionSave = "save"
)
const maxCommentLen = 1024
func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
action := r.FormValue("action")
if strings.HasPrefix(action, "delete-comment-") {
......@@ -79,14 +86,16 @@ func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
Log("id", commentID, "msg", "deleted comment")
http.Redirect(w, r, "/details?id="+r.FormValue("id"), http.StatusSeeOther)
return
}
switch action {
case "save", "delete":
case actionSave, actionDelete:
default:
err := fmt.Errorf("action missing or unknown: %q", action)
httpError(w, r, "invalid parameters", err, http.StatusBadRequest)
return
}
......@@ -96,7 +105,7 @@ func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
return
}
if action == "delete" {
if action == actionDelete {
err = h.Q.DeleteWine(r.Context(), int32(id))
if err != nil {
httpError(w, r, "can't delete wine", err, http.StatusInternalServerError)
......@@ -104,6 +113,7 @@ func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
}
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
......@@ -114,6 +124,7 @@ func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
}
rawRating := r.FormValue("rating")
rating, err := strconv.Atoi(rawRating)
if rawRating != "" && err != nil {
httpError(w, r, "can't parse rating", nil, http.StatusBadRequest)
......@@ -127,11 +138,11 @@ func (h Handler) detailsPost(w http.ResponseWriter, r *http.Request) {
}
comment := strings.TrimSpace(r.FormValue("comment"))
if len(comment) > 1024 {
comment = comment[:1024] + "..."
if len(comment) > maxCommentLen {
comment = comment[:maxCommentLen] + "..."
}
err = r.ParseMultipartForm(16 * 1024 * 1024)
err = r.ParseMultipartForm(maxFormFileSize)
if err != nil {
httpError(w, r, "can't parse multipart form data", err, http.StatusInternalServerError)
return
......@@ -213,12 +224,12 @@ func (h Handler) details() http.Handler {
tpl = template.Must(template.ParseFS(templateFS, "templates/base.tpl", "templates/details.tpl"))
})
if r.Method == "POST" {
if r.Method == http.MethodPost {
h.detailsPost(w, r)
return
}
if r.Method != "GET" {
if r.Method != http.MethodGet {
httpError(w, r, "bad method", errors.New(r.Method), http.StatusMethodNotAllowed)
return
}
......
......@@ -86,7 +86,7 @@ func (h Handler) index() http.Handler {
}
}
err = r.ParseMultipartForm(16 * 1024 * 1024)
err = r.ParseMultipartForm(maxFormFileSize)
if err != nil {
level.Warn(log.Get(r)).Log("error", err, "msg", "can't parse multipart form data")
}
......@@ -142,6 +142,6 @@ func (h Handler) index() http.Handler {
return
}
http.Redirect(w, r, "/details?id="+strconv.Itoa(int(id)), http.StatusSeeOther) // TODO: Is this the correct status?
http.Redirect(w, r, "/details?id="+strconv.Itoa(int(id)), http.StatusFound)
})
}
......@@ -12,13 +12,15 @@ import (
"git.c3pb.de/gbe/invinoveritas/session"
)
func (h Handler) userCreate(w http.ResponseWriter, r *http.Request) (string, error) {
const maxUserNameLen = 80
func (h Handler) userCreate(r *http.Request) (string, error) {
name := r.PostForm.Get("new-user")
if name == "" {
return "", errors.New("name can't be empty")
}
if len(name) > 80 {
if len(name) > maxUserNameLen {
return "", fmt.Errorf("user name too long %q", name)
}
......@@ -134,7 +136,7 @@ func (h Handler) user(page pageName) http.Handler {
case "update":
data.ShowResult = true
case "create":
generatedPw, err := h.userCreate(w, r)
generatedPw, err := h.userCreate(r)
if err != nil {
data.FormErrors["NewUser"] = err
}
......
......@@ -40,7 +40,7 @@ func Request(next http.Handler, logger log.Logger) http.Handler {
defer func() {
d := time.Since(start)
level.Info(l).
_ = level.Info(l).
Log("duration", d, "msg", "request handled")
}()
......@@ -51,7 +51,8 @@ func Request(next http.Handler, logger log.Logger) http.Handler {
})
}
// Get returns the logger for the given HTTP request. It will return a logger that discards all writes if r does not have a logger in its context.
// Get returns the logger for the given HTTP request. It will return a logger that discards all writes
// if r does not have a logger in its context.
func Get(r *http.Request) log.Logger {
return GetContext(r.Context())
}
......
......@@ -57,6 +57,8 @@ func TestLogRequest(t *testing.T) {
}
for _, tc := range testCases {
tc := tc
t.Run(tc.forwardedFor+"-"+tc.remoteAddr, func(t *testing.T) {
var log testLogger
hdlr := Request(next, &log)
......
......@@ -23,15 +23,19 @@ import (
)
//go:embed templates/*.tpl
var templateFS embed.FS
var templateFS embed.FS //nolint:gochecknoglobals
//go:embed static/*
var staticFS embed.FS
var staticFS embed.FS //nolint:gochecknoglobals
// Build info. Will be set by Gitlab pipeline.
var (
commitHash string
buildTime string
commitHash string //nolint:gochecknoglobals
buildTime string //nolint:gochecknoglobals
)
const (
maxFormFileSize = 16 * 1024 * 1024
)
func httpError(w http.ResponseWriter, r *http.Request, msg string, err error, status int) {
......@@ -78,6 +82,7 @@ func addCacheHeaders(next http.Handler) http.Handler {
// - log.Request to make a logger available in the request context
// - addCacheHeader for caching
// - auth.Require
//nolint:godot
func wrapMiddleware(hdlr http.Handler, sessions session.Provider, logger kitlog.Logger) http.Handler {
authFailed := log.Request(sessions.Handler(templateFS), logger)
return log.Request(auth.Require(hdlr, authFailed, sessions), logger)
......@@ -153,6 +158,7 @@ func main() {
})
var g run.Group
g.Add(run.SignalHandler(ctx, os.Interrupt))
{
......
......@@ -177,8 +177,6 @@ func (a Provider) Handler(templateFS fs.FS) http.Handler {
level.Debug(log).
Log("num_users", count, "msg", "got user count")
// TODO: Deal with this. Report an error if != 0 users in db, otherwise allow them through.
w.WriteHeader(http.StatusUnauthorized)
rd := responseData{
......@@ -300,18 +298,19 @@ func (a Provider) ListSessions(ctx context.Context) ([]Info, error) {
return nil, err
}
var sessions []Info
for _, r := range rows {
sessions := make([]Info, len(rows))
for i, r := range rows {
t, err := time.Parse(time.RFC3339Nano, r.Created)
if err != nil {
return nil, err
}
sessions = append(sessions, Info{
sessions[i] = Info{
Name: r.Name,
Created: t,
Remote: r.Remote,
})
}
}
return sessions, nil
......@@ -358,16 +357,18 @@ func (a Provider) UpdatePassword(ctx context.Context, userName, passOld, passNew
return auth.ErrAuthFailed
}
level.Info(log.GetContext(ctx)).
Log("name", userName, "msg", "updating password")
level.Info(log.GetContext(ctx)).Log("name", userName, "msg", "updating password")
// Either user doesn't exist yet (no entry) or the password they gave as their old matched. We
// can update.
// Try an update first, if that fails, an insert
// Either user doesn't exist yet (no entry) or the password they gave as their old matched. We
// can update.
// Try an update first, if that fails, an insert
res, err := q.UpdateUser(ctx, query.UpdateUserParams{
Name: userName,
Password: hashedNew,
})
if err != nil {
return err
}
n, err := res.RowsAffected()
if err != nil {
......
......@@ -3,6 +3,7 @@ package query
import (
"context"
"database/sql"
"fmt"
)
type TXFunc func(ctx context.Context, q *Queries) error
......@@ -27,7 +28,11 @@ func (q *Queries) RunTx(ctx context.Context, txf TXFunc) (err error) {
defer func() {
if err != nil {
tx.Rollback()
rErr := tx.Rollback()
if rErr != nil {
err = fmt.Errorf("rollback error %s while handling %w", rErr, err)
}
return
}
......
......@@ -5,18 +5,12 @@ import (
"context"
"database/sql"
"embed"
"encoding/csv"
"errors"
"fmt"
"image"
"image/png"
"io"
"io/fs"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"git.c3pb.de/gbe/invinoveritas/log"
"git.c3pb.de/gbe/invinoveritas/storage/query"
......@@ -62,6 +56,7 @@ func AddPicture(ctx context.Context, q *query.Queries, wineID int, fh io.Reader,
draw.ApproxBiLinear.Scale(scaled, rect, img, bounds, draw.Over, nil)
var buf bytes.Buffer
err = png.Encode(&buf, scaled)
if err != nil {
return fmt.Errorf("encoding image: %w", err)
......@@ -79,7 +74,7 @@ func AddPicture(ctx context.Context, q *query.Queries, wineID int, fh io.Reader,
}
//go:embed migrations/*.sql
var migrationFS embed.FS
var migrationFS embed.FS //nolint:gochecknoglobals
func Open(ctx context.Context, dbPath, dumpPath string, logger kitlog.Logger) (*sql.DB, error) {
db, err := sql.Open("sqlite", dbPath)
......@@ -157,179 +152,5 @@ func Open(ctx context.Context, dbPath, dumpPath string, logger kitlog.Logger) (*
return nil, fmt.Errorf("applying migrations: %w", err)
}
err = q.RunTx(ctx, func(ctx context.Context, q *query.Queries) error {
v, err := q.GetState(ctx, "migrated-from-ql")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if v == "true" {
level.Info(logger).Log("msg", "dump already imported")
return nil
}
level.Info(logger).Log("msg", "attempting import of dumped db", "path", dumpPath)
winesFH, err := os.Open(filepath.Join(dumpPath, "wines.csv"))
if err != nil {
level.Warn(logger).Log("msg", "can't open wines dump, aborting import", "error", err)
return nil
}
defer winesFH.Close()
commentsFH, err := os.Open(filepath.Join(dumpPath, "comments.csv"))
if err != nil {
level.Warn(logger).Log("msg", "can't open comments dump, aborting import", "error", err)
return nil
}
defer commentsFH.Close()
// Import wines first
r := csv.NewReader(winesFH)
r.TrimLeadingSpace = true
wines := make(map[int]int64) // Maps old ID to new ID
for {
record, err := r.Read()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return err
}
if len(record) != 5 {
return fmt.Errorf("unexpected record length %d", len(record))
}
rawID, name, rawRating := record[0], record[1], record[2]
id, err := strconv.Atoi(rawID)
if err != nil {
return err
}
country := record[3]
img, err := decodeImg(record[4])
if err != nil {
return err
}
rating, err := strconv.Atoi(rawRating)
if rawRating != "" && err != nil {
return err
}
level.Debug(logger).Log("id", id, "name", name, "rating", rating, "country", country, "img_bytes", len(img))
res, err := q.InsertWine(ctx, query.InsertWineParams{
Name: name,
Rating: sql.NullInt32{
Int32: int32(rating),
Valid: rawRating != "",
},
Country: sql.NullString{
String: country,
Valid: country != "",
},
})
if err != nil {
return err
}
newID, err := res.LastInsertId()
if err != nil {
return err
}
wines[id] = newID
err = q.StorePicture(ctx, query.StorePictureParams{
WineID: int32(newID),
Picture: img,
})
if err != nil {
return err
}
}
// And then the comments
r = csv.NewReader(commentsFH)
r.TrimLeadingSpace = true
for {
record, err := r.Read()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return err
}
if len(record) != 3 {
return fmt.Errorf("unexpected record length %d", len(record))
}
rawWineID, comment := record[1], record[2]
wineID, err := strconv.Atoi(rawWineID)
if err != nil {
return err
}
level.Debug(logger).Log("wine_id", wines[wineID], "comment", comment)
err = q.AddComment(ctx, query.AddCommentParams{
WineID: int32(wines[wineID]),
Comment: comment,
})
if err != nil {
return err
}
}
err = q.SetState(ctx, query.SetStateParams{
Key: "migrated-from-ql",
Val: "true",
})
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, fmt.Errorf("importing dumped DB: %w", err)
}
return db, nil
}
func decodeImg(enc string) ([]byte, error) {
if enc == "<nil>" {
return nil, nil
}
if len(enc) < 3 {
return nil, errors.New("short encoding")
}
if enc[0] != '[' || enc[len(enc)-1] != ']' {
return nil, errors.New("malformed header/trailer")
}
parts := strings.Split(enc[1:len(enc)-1], " ")
decoded := make([]byte, len(parts))
for i, p := range parts {
b, err := strconv.Atoi(p)
if err != nil {
return nil, err
}
if b < 0 || b > 255 {
return nil, fmt.Errorf("out of range: %d", b)
}
decoded[i] = byte(b)
}
return decoded, nil
}
......@@ -13,14 +13,14 @@ import (
type ISO2CountryCode [2]byte
var UnknownCountry = ISO2CountryCode{'X', 'X'}
var UnknownCountry = ISO2CountryCode{'X', 'X'} //nolint:gochecknoglobals
func ISO2CountryCodeFromString(s string) (ISO2CountryCode, error) {
if len(s) == 0 {
return UnknownCountry, nil
}
if len(s) != 2 {
if len(s) != len(UnknownCountry) {
return UnknownCountry, errors.New("invalid length")
}
......@@ -47,7 +47,7 @@ func (i *ISO2CountryCode) UnmarshalBinary(d []byte) error {
return nil
}
if len(d) != 2 {
if len(d) != len(UnknownCountry) {
*i = UnknownCountry
return nil
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment