package main import ( "bytes" "context" "database/sql" "database/sql/driver" "errors" "fmt" "image" "image/png" "io" "sort" // Imported for side effects to register format handlers _ "image/jpeg" _ "image/png" "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" log "github.com/sirupsen/logrus" "golang.org/x/image/draw" ) var errNotFound = errors.New("not found") type Comment struct { ID int `db:"id"` Content string `db:"content"` } type ISO2CountryCode [2]byte var UnknownCountry = ISO2CountryCode{'X', 'X'} func ISO2CountryCodeFromString(s string) (ISO2CountryCode, error) { if len(s) == 0 { return UnknownCountry, nil } if len(s) != 2 { return UnknownCountry, errors.New("invalid length") } return ISO2CountryCode{s[0], s[1]}, nil } func (i ISO2CountryCode) String() string { if i[0] == 0 || i[1] == 0 { return "XX" } return fmt.Sprintf("%c%c", i[0], i[1]) } func (i *ISO2CountryCode) UnmarshalBinary(d []byte) error { if len(d) == 0 { *i = UnknownCountry return nil } if len(d) != 2 { *i = UnknownCountry return nil } copy(i[:], d) return nil } func (i ISO2CountryCode) Value() (driver.Value, error) { return i.String(), nil } func (i *ISO2CountryCode) Scan(data interface{}) error { var raw string switch data := data.(type) { case string: raw = data case []byte: raw = string(data) default: return fmt.Errorf("can't convert from %T", data) } c, err := ISO2CountryCodeFromString(raw) if err != nil { return err } *i = c return nil } type Vino struct { ID int `db:"id"` Name string `db:"name"` Rating int `db:"rating"` Country ISO2CountryCode `db:"country"` Picture image.Image `db:"-"` HasPicture bool `db:"has_picture"` // Set to true if there's picture data for this vino Comments []Comment `db:"-"` } func DeleteVino(ctx context.Context, db *sqlx.DB, id int) (err error) { tx, err := db.Begin() if err != nil { return err } defer func() { if err != nil { log.Println("rolling back transaction") tx.Rollback() return } tx.Commit() }() _, err = tx.ExecContext(ctx, `DELETE FROM comments WHERE wine = ?1`, id) if err != nil { return fmt.Errorf("deleting comments for %d: %w", id, err) } _, err = tx.ExecContext(ctx, `DELETE FROM wines WHERE id() = ?1`, id) if err != nil { return fmt.Errorf("deleting wine %d: %w", id, err) } return nil } func (v *Vino) loadComments(ctx context.Context, tx *sqlx.Tx) error { err := tx.SelectContext(ctx, &v.Comments, ` SELECT id() as id, content FROM comments WHERE wine = ?1`, v.ID) if err != nil { return err } return nil } // loadVino uses the given read-only bolt transaction to load the data for the wine with the given id. When // there are no wines at all, or there is no wine with the given ID, loadVino returns errNotFound. func loadVino(ctx context.Context, tx *sqlx.Tx, id int) (Vino, error) { var v Vino err := tx.GetContext(ctx, &v, ` SELECT id() as id, name, rating, country, picture IS NOT NULL AS has_picture FROM wines WHERE id() = ?1`, id) if err != nil { return v, err } err = v.loadComments(ctx, tx) if err != nil { return v, err } return v, nil } func LoadVino(ctx context.Context, db *sqlx.DB, id int) (Vino, error) { tx, err := db.Beginx() if err != nil { return Vino{}, err } defer tx.Rollback() // No change in tx intended, it's only used for read consistency v, err := loadVino(ctx, tx, id) if err != nil { return v, err } return v, nil } func ListWines(ctx context.Context, db *sqlx.DB) ([]Vino, error) { var wines []Vino tx, err := db.Beginx() if err != nil { return nil, err } defer tx.Rollback() // No write intended in this tx, it's only for read consistency err = tx.SelectContext(ctx, &wines, ` SELECT id() as id, name, rating, country, picture IS NOT NULL AS has_picture FROM wines`) if err != nil { return nil, err } // Load comments for _, v := range wines { err = v.loadComments(ctx, tx) if err != nil { return nil, err } } sort.Slice(wines, func(i, j int) bool { return wines[i].Rating > wines[j].Rating }) return wines, nil } func LoadPictureData(ctx context.Context, db *sqlx.DB, id int) ([]byte, error) { var data []byte err := db.GetContext(ctx, &data, `SELECT picture FROM wines WHERE id() = ?1`, id) if errors.Is(err, sql.ErrNoRows) { return nil, errNotFound } if err != nil { return nil, err } if len(data) == 0 { log.Println("zero length image data") return nil, errNotFound } return data, nil } // AddPicture loads picture data (PNG or JPEG) from fh and sets v's picture to it. // If something goes wrong during loading, or the image is neither PNG nor JPEG, an error // is returned. If contentType is not the empty string, it is validated to be either // image/png or image/jpeg. func (v *Vino) AddPicture(fh io.Reader, contentType string) error { switch contentType { case "", "image/jpeg", "image/png": default: return fmt.Errorf("unexpected content type for image: %q", contentType) } img, _, err := image.Decode(fh) if err != nil { return err } v.Picture = img return nil } func (v Vino) String() string { return fmt.Sprintf("{Name: %q, Rating: %d}", v.Name, v.Rating) } func (v *Vino) Store(ctx context.Context, db *sqlx.DB) (err error) { // Encode scaled image as PNG, will contain image data if there is any values := map[string]interface{}{ "name": v.Name, "rating": v.Rating, "country": v.Country, } if v.Picture != nil { // Scale image down and encode as PNG // Get aspect ratio of incoming picture bounds := v.Picture.Bounds() aspect := float64(bounds.Max.X) / float64(bounds.Max.Y) const destHeight = 800 rect := image.Rect(0, 0, int(destHeight*aspect), destHeight) log.WithFields(log.Fields{ "bounds": bounds, "aspect": aspect, "rect": rect, }).Info("resizing image") scaled := image.NewRGBA(rect) draw.ApproxBiLinear.Scale(scaled, rect, v.Picture, v.Picture.Bounds(), draw.Over, nil) var img bytes.Buffer err := png.Encode(&img, scaled) if err != nil { return err } values["picture"] = img.Bytes() } var ( query string args []interface{} ) if v.ID != 0 { query, args, err = squirrel.Update("wines"). Where(squirrel.Eq{"id()": v.ID}). PlaceholderFormat(squirrel.Dollar). SetMap(values). ToSql() } else { query, args, err = squirrel.Insert("wines"). PlaceholderFormat(squirrel.Dollar). SetMap(values). ToSql() } if err != nil { return err } tx, err := db.Beginx() if err != nil { return err } defer func() { if err != nil { tx.Rollback() return } tx.Commit() }() res, err := tx.ExecContext(ctx, query, args...) if err != nil { return err } if v.ID == 0 { id, err := res.LastInsertId() if err != nil { return err } v.ID = int(id) } return nil } func (v *Vino) StoreComment(ctx context.Context, db *sqlx.DB, text string) (err error) { tx, err := db.Begin() if err != nil { return err } _, err = tx.ExecContext(ctx, `INSERT INTO comments (wine, content) VALUES (?1, ?2)`, v.ID, text) if err != nil { tx.Rollback() return err } tx.Commit() return nil }