From d3551a235964a0de1d286dd17a097327b8973a09 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Wed, 25 May 2022 20:33:13 +0800 Subject: [PATCH] fix insert duplicate users/items/feedback (#465) --- storage/data/database.go | 13 +- storage/data/database_test.go | 15 ++ storage/data/sql.go | 313 ++++++++-------------------------- 3 files changed, 92 insertions(+), 249 deletions(-) diff --git a/storage/data/database.go b/storage/data/database.go index ff0734aab..c88b9950a 100644 --- a/storage/data/database.go +++ b/storage/data/database.go @@ -35,17 +35,16 @@ import ( var ( ErrUserNotExist = errors.NotFoundf("user") ErrItemNotExist = errors.NotFoundf("item") - ErrUnsupported = errors.NotSupportedf("interface") ErrNoDatabase = errors.NotAssignedf("database") ) // Item stores meta data about item. type Item struct { - ItemId string + ItemId string `gorm:"primaryKey"` IsHidden bool - Categories []string `gorm:"-"` + Categories []string `gorm:"serializer:json"` Timestamp time.Time - Labels []string `gorm:"-"` + Labels []string `gorm:"serializer:json"` Comment string } @@ -64,9 +63,9 @@ type ItemPatch struct { // User stores meta data about user. type User struct { - UserId string - Labels []string `gorm:"-"` - Subscribe []string `gorm:"-"` + UserId string `gorm:"primaryKey"` + Labels []string `gorm:"serializer:json"` + Subscribe []string `gorm:"serializer:json"` Comment string } diff --git a/storage/data/database_test.go b/storage/data/database_test.go index a89932dd2..18f55c9ab 100644 --- a/storage/data/database_test.go +++ b/storage/data/database_test.go @@ -180,6 +180,10 @@ func testUsers(t *testing.T, db Database) { // test insert empty err = db.BatchInsertUsers(nil) assert.NoError(t, err) + + // insert duplicate users + err = db.BatchInsertUsers([]User{{UserId: "1"}, {UserId: "1"}}) + assert.NoError(t, err) } func testFeedback(t *testing.T, db Database) { @@ -333,6 +337,13 @@ func testFeedback(t *testing.T, db Database) { {FeedbackKey: FeedbackKey{"a", "100", "200"}}, }, false, false, false) assert.NoError(t, err) + + // insert duplicate feedback + err = db.BatchInsertFeedback([]Feedback{ + {FeedbackKey: FeedbackKey{"a", "0", "0"}}, + {FeedbackKey: FeedbackKey{"a", "0", "0"}}, + }, true, true, true) + assert.NoError(t, err) } func testItems(t *testing.T, db Database) { @@ -435,6 +446,10 @@ func testItems(t *testing.T, db Database) { items, err = db.BatchGetItems(nil) assert.NoError(t, err) assert.Empty(t, items) + + // test insert duplicate items + err = db.BatchInsertItems([]Item{{ItemId: "1"}, {ItemId: "1"}}) + assert.NoError(t, err) } func testDeleteUser(t *testing.T, db Database) { diff --git a/storage/data/sql.go b/storage/data/sql.go index cab4993d6..3f6058447 100644 --- a/storage/data/sql.go +++ b/storage/data/sql.go @@ -21,6 +21,7 @@ import ( "github.com/juju/errors" _ "github.com/lib/pq" _ "github.com/mailru/go-clickhouse" + "github.com/samber/lo" "github.com/scylladb/go-set/strset" "github.com/zhenghaoz/gorse/base/json" "github.com/zhenghaoz/gorse/base/log" @@ -200,8 +201,18 @@ func (d *SQLDatabase) BatchInsertItems(items []Item) error { case ClickHouse: builder.WriteString("INSERT INTO items(item_id, is_hidden, categories, time_stamp, labels, comment, version) VALUES ") } + memo := strset.New() var args []interface{} - for i, item := range items { + for _, item := range items { + // remove duplicate items + if memo.Has(item.ItemId) { + continue + } else { + memo.Add(item.ItemId) + } + if len(args) > 0 { + builder.WriteString(",") + } labels, err := json.Marshal(item.Labels) if err != nil { return errors.Trace(err) @@ -218,9 +229,6 @@ func (d *SQLDatabase) BatchInsertItems(items []Item) error { case ClickHouse: builder.WriteString("(?,?,?,?,?,?,NOW())") } - if i+1 < len(items) { - builder.WriteString(",") - } if d.driver == ClickHouse { args = append(args, item.ItemId, item.IsHidden, string(categories), item.Timestamp.In(time.UTC), string(labels), item.Comment) } else { @@ -310,142 +318,39 @@ func (d *SQLDatabase) ModifyItem(itemId string, patch ItemPatch) error { log.Logger().Debug("empty item patch") return nil } - var builder strings.Builder - var args []interface{} - delimiter := " " - switch d.driver { - case MySQL: - builder.WriteString("UPDATE items SET") - if patch.IsHidden != nil { - builder.WriteString(delimiter) - builder.WriteString("is_hidden = ?") - args = append(args, patch.IsHidden) - delimiter = ", " - } - if patch.Categories != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Categories) - builder.WriteString("`categories` = ?") - args = append(args, text) - delimiter = ", " - } - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString("`comment` = ?") - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString("`labels` = ?") - args = append(args, text) - delimiter = ", " - } - if patch.Timestamp != nil { - builder.WriteString(delimiter) - builder.WriteString("time_stamp = ?") - args = append(args, patch.Timestamp) - } - builder.WriteString(" WHERE item_id = ?") - args = append(args, itemId) - case Postgres: - builder.WriteString("UPDATE items SET") - if patch.IsHidden != nil { - builder.WriteString(delimiter) - builder.WriteString(fmt.Sprintf("is_hidden = $%d", len(args)+1)) - args = append(args, patch.IsHidden) - delimiter = ", " - } - if patch.Categories != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Categories) - builder.WriteString(fmt.Sprintf("categories = $%d", len(args)+1)) - args = append(args, text) - delimiter = ", " - } - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString(fmt.Sprintf("comment = $%d", len(args)+1)) - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString(fmt.Sprintf("labels = $%d", len(args)+1)) - args = append(args, text) - delimiter = ", " - } - if patch.Timestamp != nil { - builder.WriteString(delimiter) - builder.WriteString(fmt.Sprintf("time_stamp = $%d", len(args)+1)) - args = append(args, patch.Timestamp) - } - builder.WriteString(fmt.Sprintf(" WHERE item_id = $%d", len(args)+1)) - args = append(args, itemId) - case ClickHouse: - builder.WriteString("ALTER TABLE items UPDATE") - if patch.IsHidden != nil { - builder.WriteString(delimiter) - builder.WriteString("is_hidden = ?") - args = append(args, patch.IsHidden) - delimiter = ", " - } - if patch.Categories != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Categories) - builder.WriteString("`categories` = ?") - args = append(args, string(text)) - delimiter = ", " - } - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString("`comment` = ?") - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString("`labels` = ?") - args = append(args, string(text)) - delimiter = ", " - } - if patch.Timestamp != nil { - builder.WriteString(delimiter) - builder.WriteString("time_stamp = ?") - args = append(args, patch.Timestamp.In(time.UTC)) - } - builder.WriteString(" WHERE item_id = ?") - args = append(args, itemId) + attributes := make(map[string]any) + if patch.IsHidden != nil { + attributes["is_hidden"] = *patch.IsHidden } - _, err := d.client.Exec(builder.String(), args...) + if patch.Categories != nil { + text, _ := json.Marshal(patch.Categories) + attributes["categories"] = string(text) + } + if patch.Comment != nil { + attributes["comment"] = *patch.Comment + } + if patch.Labels != nil { + text, _ := json.Marshal(patch.Labels) + attributes["labels"] = string(text) + } + if patch.Timestamp != nil { + if d.driver == ClickHouse { + attributes["time_stamp"] = patch.Timestamp.In(time.UTC) + } else { + attributes["time_stamp"] = patch.Timestamp + } + } + err := d.gormDB.Model(&Item{ItemId: itemId}).Updates(attributes).Error return errors.Trace(err) } // GetItems returns items from MySQL. func (d *SQLDatabase) GetItems(cursor string, n int, timeLimit *time.Time) (string, []Item, error) { - var result *sql.Rows - var err error - switch d.driver { - case MySQL, ClickHouse: - if timeLimit == nil { - result, err = d.client.Query("SELECT item_id, is_hidden, categories, time_stamp, labels, `comment` FROM items "+ - "WHERE item_id >= ? ORDER BY item_id LIMIT ?", cursor, n+1) - } else { - result, err = d.client.Query("SELECT item_id, is_hidden, categories, time_stamp, labels, `comment` FROM items "+ - "WHERE item_id >= ? AND time_stamp >= ? ORDER BY item_id LIMIT ?", cursor, *timeLimit, n+1) - } - case Postgres: - if timeLimit == nil { - result, err = d.client.Query("SELECT item_id, is_hidden, categories, time_stamp, labels, comment FROM items "+ - "WHERE item_id >= $1 ORDER BY item_id LIMIT $2", cursor, n+1) - } else { - result, err = d.client.Query("SELECT item_id, is_hidden, categories, time_stamp, labels, comment FROM items "+ - "WHERE item_id >= $1 AND time_stamp >= $2 ORDER BY item_id LIMIT $3", cursor, *timeLimit, n+1) - } + tx := d.gormDB.Table("items").Select("item_id, is_hidden, categories, time_stamp, labels, comment").Where("item_id >= ?", cursor) + if timeLimit != nil { + tx.Where("time_stamp >= ?", *timeLimit) } + result, err := tx.Order("item_id").Limit(n + 1).Rows() if err != nil { return "", nil, errors.Trace(err) } @@ -556,8 +461,18 @@ func (d *SQLDatabase) BatchInsertUsers(users []User) error { case ClickHouse: builder.WriteString("INSERT INTO users(user_id, labels, subscribe, comment, version) VALUES ") } + memo := strset.New() var args []interface{} - for i, user := range users { + for _, user := range users { + // remove duplicate users + if memo.Has(user.UserId) { + continue + } else { + memo.Add(user.UserId) + } + if len(args) > 0 { + builder.WriteString(",") + } labels, err := json.Marshal(user.Labels) if err != nil { return errors.Trace(err) @@ -574,9 +489,6 @@ func (d *SQLDatabase) BatchInsertUsers(users []User) error { case ClickHouse: builder.WriteString("(?,?,?,?,NOW())") } - if i+1 < len(users) { - builder.WriteString(",") - } args = append(args, user.UserId, string(labels), string(subscribe), user.Comment) } switch d.driver { @@ -636,78 +548,21 @@ func (d *SQLDatabase) ModifyUser(userId string, patch UserPatch) error { log.Logger().Debug("empty user patch") return nil } - var builder strings.Builder - var args []interface{} - delimiter := " " - switch d.driver { - case MySQL: - builder.WriteString("UPDATE users SET") - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString("`comment` = ?") - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString("`labels` = ?") - args = append(args, text) - } - builder.WriteString(" WHERE user_id = ?") - args = append(args, userId) - case Postgres: - builder.WriteString("UPDATE users SET") - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString(fmt.Sprintf("comment = $%d", len(args)+1)) - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString(fmt.Sprintf("labels = $%d", len(args)+1)) - args = append(args, text) - } - builder.WriteString(fmt.Sprintf(" WHERE user_id = $%d", len(args)+1)) - args = append(args, userId) - case ClickHouse: - builder.WriteString("ALTER TABLE users UPDATE") - if patch.Comment != nil { - builder.WriteString(delimiter) - builder.WriteString("`comment` = ?") - args = append(args, patch.Comment) - delimiter = ", " - } - if patch.Labels != nil { - builder.WriteString(delimiter) - text, _ := json.Marshal(patch.Labels) - builder.WriteString("`labels` = ?") - args = append(args, string(text)) - } - builder.WriteString(" WHERE user_id = ?") - args = append(args, userId) + attributes := make(map[string]any) + if patch.Comment != nil { + attributes["comment"] = *patch.Comment } - _, err := d.client.Exec(builder.String(), args...) + if patch.Labels != nil { + text, _ := json.Marshal(patch.Labels) + attributes["labels"] = string(text) + } + err := d.gormDB.Model(&User{UserId: userId}).Updates(attributes).Error return errors.Trace(err) } // GetUsers returns users from MySQL. func (d *SQLDatabase) GetUsers(cursor string, n int) (string, []User, error) { - var result *sql.Rows - var err error - switch d.driver { - case MySQL: - result, err = d.client.Query("SELECT user_id, labels, subscribe, `comment` FROM users "+ - "WHERE user_id >= ? ORDER BY user_id LIMIT ?", cursor, n+1) - case Postgres: - result, err = d.client.Query("SELECT user_id, labels, subscribe, comment FROM users "+ - "WHERE user_id >= $1 ORDER BY user_id LIMIT $2", cursor, n+1) - case ClickHouse: - result, err = d.client.Query("SELECT user_id, labels, subscribe, `comment` FROM users "+ - "WHERE user_id >= ? ORDER BY user_id LIMIT ?", cursor, n+1) - } + result, err := d.gormDB.Table("users").Select("user_id, labels, subscribe, comment").Where("user_id >= ?", cursor).Order("user_id").Limit(n + 1).Rows() if err != nil { return "", nil, errors.Trace(err) } @@ -941,8 +796,15 @@ func (d *SQLDatabase) BatchInsertFeedback(feedback []Feedback, insertUser, inser builder.WriteString("INSERT INTO feedback(feedback_type, user_id, item_id, time_stamp, comment) VALUES ") } var args []interface{} + memo := make(map[lo.Tuple3[string, string, string]]struct{}) for _, f := range feedback { if users.Has(f.UserId) && items.Has(f.ItemId) { + // remove duplicate feedback + if _, exist := memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}]; exist { + continue + } else { + memo[lo.Tuple3[string, string, string]{f.FeedbackType, f.UserId, f.ItemId}] = struct{}{} + } if len(args) > 0 { builder.WriteString(",") } @@ -994,49 +856,16 @@ func (d *SQLDatabase) GetFeedback(cursor string, n int, timeLimit *time.Time, fe return "", nil, err } } - var result *sql.Rows - var err error - var builder strings.Builder - switch d.driver { - case MySQL, ClickHouse: - builder.WriteString("SELECT feedback_type, user_id, item_id, time_stamp, `comment` FROM feedback WHERE time_stamp <= NOW() AND (feedback_type, user_id, item_id) >= (?,?,?)") - case Postgres: - builder.WriteString("SELECT feedback_type, user_id, item_id, time_stamp, comment FROM feedback WHERE time_stamp <= NOW() AND (feedback_type, user_id, item_id) >= ($1,$2,$3)") - } - args := []interface{}{cursorKey.FeedbackType, cursorKey.UserId, cursorKey.ItemId} + tx := d.gormDB.Table("feedback").Select("feedback_type, user_id, item_id, time_stamp, comment"). + Where("time_stamp <= NOW() AND (feedback_type, user_id, item_id) >= (?,?,?)", cursorKey.FeedbackType, cursorKey.UserId, cursorKey.ItemId) if len(feedbackTypes) > 0 { - builder.WriteString(" AND feedback_type IN (") - for i, feedbackType := range feedbackTypes { - switch d.driver { - case MySQL, ClickHouse: - builder.WriteString("?") - case Postgres: - builder.WriteString(fmt.Sprintf("$%d", len(args)+1)) - } - if i+1 < len(feedbackTypes) { - builder.WriteString(",") - } - args = append(args, feedbackType) - } - builder.WriteString(")") + tx.Where("feedback_type IN ?", feedbackTypes) } if timeLimit != nil { - switch d.driver { - case MySQL, ClickHouse: - builder.WriteString(" AND time_stamp >= ?") - case Postgres: - builder.WriteString(fmt.Sprintf(" AND time_stamp >= $%d", len(args)+1)) - } - args = append(args, *timeLimit) + tx.Where("time_stamp >= ?", *timeLimit) } - switch d.driver { - case MySQL, ClickHouse: - builder.WriteString(" ORDER BY feedback_type, user_id, item_id LIMIT ?") - case Postgres: - builder.WriteString(fmt.Sprintf(" ORDER BY feedback_type, user_id, item_id LIMIT $%d", len(args)+1)) - } - args = append(args, n+1) - result, err = d.client.Query(builder.String(), args...) + tx.Order("feedback_type, user_id, item_id").Limit(n + 1) + result, err := tx.Rows() if err != nil { return "", nil, errors.Trace(err) }