Skip to content

Commit

Permalink
Merge pull request #5 from go-gorm/feature/pk
Browse files Browse the repository at this point in the history
  • Loading branch information
huacnlee committed Jan 25, 2022
2 parents 81e8137 + 6736eca commit e15a286
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 109 deletions.
32 changes: 8 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Gorm Sharding 是一个高性能的数据库分表中间件。
- Non-intrusive design. Load the plugin, specify the config, and all done.
- Lighting-fast. No network based middlewares, as fast as Go.
- Multiple database support. PostgreSQL tested, MySQL and SQLite is coming.
- Allows you custom the Primary Key generator ([Longkey](https://github.com/longbridgeapp/longkey), Sequence, Snowflake ...).
- Integrated primary key generator (Snowflake, PostgreSQL Sequence, Custom, ...).

## Sharding process

Expand All @@ -40,28 +40,13 @@ Config the sharding middleware, register the tables which you want to shard. See

```go
db.Use(sharding.Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if user_id, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", user_id%64), nil
}
return "", errors.New("invalid user_id")
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
// use LongKey for generate a sequence primary key with table index
return longkey.Next(tableIdx)
}
ShardingKey: "user_id",
NumberOfShards: 64,
PrimaryKeyGenerator: sharding.PKSnowflake,
}, "orders").Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if user_id, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", user_id%256), nil
}
return "", errors.New("invalid user_id")
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
return snowflake_node.Generate().Int64()
}
ShardingKey: "user_id",
NumberOfShards: 256,
PrimaryKeyGenerator: sharding.PKSnowflake,
// This case for show up give notifications, audit_logs table use same sharding rule.
}, Notification{}, AuditLog{}))
```
Expand Down Expand Up @@ -107,9 +92,8 @@ When you sharding tables, you need consider how the primary key generate.

Recommend options:

- [LongKey](https://github.com/longbridgeapp/longkey)
- [Database sequence by manully](https://www.postgresql.org/docs/current/sql-createsequence.html)
- [Snowflake](https://github.com/bwmarrin/snowflake)
- [Database sequence by manully](https://www.postgresql.org/docs/current/sql-createsequence.html)

## License

Expand Down
20 changes: 3 additions & 17 deletions examples/order.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package main

import (
"errors"
"fmt"

"github.com/bwmarrin/snowflake"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/sharding"
Expand Down Expand Up @@ -33,22 +31,10 @@ func main() {
)`)
}

node, err := snowflake.NewNode(1)
if err != nil {
panic(err)
}

middleware := sharding.Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if uid, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", uid%64), nil
}
return "", errors.New("invalid user_id")
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
return node.Generate().Int64()
},
ShardingKey: "user_id",
NumberOfShards: 64,
PrimaryKeyGenerator: sharding.PKSnowflake,
}, "orders")
db.Use(middleware)

Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.17
require (
github.com/bwmarrin/snowflake v0.3.0
github.com/longbridgeapp/assert v0.1.0
github.com/longbridgeapp/longkey v0.1.0
github.com/longbridgeapp/sqlparser v0.2.0
gorm.io/driver/postgres v1.1.0
gorm.io/gorm v1.21.16
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-b
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/longbridgeapp/assert v0.1.0 h1:KkQlHUJSpuUFkUDjwBJgghFl31+wwSDHTq/WRrvLjko=
github.com/longbridgeapp/assert v0.1.0/go.mod h1:ew3umReliXtk1bBG4weVURxdvR0tsN+rCEfjnA4YfxI=
github.com/longbridgeapp/longkey v0.1.0 h1:FW7I89nQNVYal3n0RBSy1eusQQmtTHNLcPsiXHDrEFM=
github.com/longbridgeapp/longkey v0.1.0/go.mod h1:Wt5u8YLL9HThTU3ecmu+BgXxsq73CQxbWYr5ssEvTuA=
github.com/longbridgeapp/sqlparser v0.2.0 h1:A6gvcqGYWpLrLbD2OoXoiMQsyUc9bg24aah902dXgN8=
github.com/longbridgeapp/sqlparser v0.2.0/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=
Expand Down
136 changes: 122 additions & 14 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,34 @@ package sharding
import (
"errors"
"fmt"
"hash/crc32"
"strconv"
"strings"
"sync"

"github.com/bwmarrin/snowflake"
"github.com/longbridgeapp/sqlparser"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)

const (
PKSnowflake = iota // Use Snowflake primary key generator
PKPGSequence // Use PostgreSQL sequence primary key generator
PKCustom // Use custom primary key generator
)

var (
ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
ErrInvalidID = errors.New("invalid id format")
)

type Sharding struct {
*gorm.DB
ConnPool *ConnPool
configs map[string]Config
querys sync.Map
ConnPool *ConnPool
configs map[string]Config
querys sync.Map
snowflakeNodes []*snowflake.Node
}

// Config specifies the configuration for sharding.
Expand All @@ -33,6 +42,12 @@ type Config struct {
// For example, for a product order table, you may want to split the rows by `user_id`.
ShardingKey string

// NumberOfShards specifies how many tables you want to sharding.
NumberOfShards uint

// tableFormat specifies the sharding table suffix format.
tableFormat string

// ShardingAlgorithm specifies a function to generate the sharding
// table's suffix by the column value.
// For example, this function implements a mod sharding algorithm.
Expand All @@ -47,26 +62,27 @@ type Config struct {

// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
// table's suffix by the primary key. Used when no sharding key specified.
// For example, this function use the LongKey library to generate the suffix.
// For example, this function use the Snowflake library to generate the suffix.
//
// func(id int64) (suffix string) {
// return fmt.Sprintf("_%02d", longkey.TableIdx(id))
// return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node())
// }
ShardingAlgorithmByPrimaryKey func(id int64) (suffix string)

// PrimaryKeyGenerate specifies a function to generate the primary key.
// PrimaryKeyGenerator specifies the primary key generate algorithm.
// Used only when insert and the record does not contains an id field.
// We recommend you use the
// [LongKey](https://github.com/longbridgeapp/longkey) component,
// it is a distributed primary key generator.
// Options are PKSnowflake, PKPGSequence and PKCustom.
// When use PKCustom, you should also specify PrimaryKeyGeneratorFn.
PrimaryKeyGenerator int

// PrimaryKeyGeneratorFn specifies a function to generate the primary key.
// When use auto-increment like generator, the tableIdx argument could ignored.
//
// For example, this function use the LongKey library to generate the primary key.
// For example, this function use the Snowflake library to generate the primary key.
//
// func(tableIdx int64) int64 {
// return longkey.Next(tableIdx)
// return nodes[tableIdx].Generate().Int64()
// }
PrimaryKeyGenerate func(tableIdx int64) int64
PrimaryKeyGeneratorFn func(tableIdx int64) int64
}

func Register(config Config, tables ...interface{}) *Sharding {
Expand All @@ -87,6 +103,75 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
}
}

for t, c := range s.configs {
if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake {
panic("Snowflake NumberOfShards should less than 1024")
}

if c.PrimaryKeyGenerator == PKSnowflake {
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.snowflakeNodes[index].Generate().Int64()
}
} else if c.PrimaryKeyGenerator == PKPGSequence {
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
var id int64
err := s.DB.Raw("SELECT nextval('" + pgSeqName(t) + "')").Scan(&id).Error
if err != nil {
panic(err)
}
return id
}
} else if c.PrimaryKeyGenerator == PKCustom {
if c.PrimaryKeyGeneratorFn == nil {
panic("PrimaryKeyGeneratorFn not configured")
}
} else {
panic("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom")
}

if c.ShardingAlgorithm == nil {
if c.NumberOfShards == 0 {
panic("specify NumberOfShards or ShardingAlgorithm")
}
if c.NumberOfShards < 10 {
c.tableFormat = "_%01d"
} else if c.NumberOfShards < 100 {
c.tableFormat = "_%02d"
} else if c.NumberOfShards < 1000 {
c.tableFormat = "_%03d"
} else if c.NumberOfShards < 10000 {
c.tableFormat = "_%04d"
}
c.ShardingAlgorithm = func(value interface{}) (suffix string, err error) {
id := 0
switch value := value.(type) {
case int:
id = value
case int64:
id = int(value)
case string:
id, err = strconv.Atoi(value)
if err != nil {
id = int(crc32.ChecksumIEEE([]byte(value)))
}
default:
return "", fmt.Errorf("default algorithm only support integer and string column," +
"if you use other type, specify you own ShardingAlgorithm")
}
return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
}
}

if c.ShardingAlgorithmByPrimaryKey == nil {
if c.PrimaryKeyGenerator == PKSnowflake {
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node())
}
}
}
s.configs[t] = c
}

return s
}

Expand All @@ -108,6 +193,25 @@ func (s *Sharding) LastQuery() string {
func (s *Sharding) Initialize(db *gorm.DB) error {
s.DB = db
s.registerConnPool(db)

for t, c := range s.configs {
if c.PrimaryKeyGenerator == PKPGSequence {
err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error
if err != nil {
return fmt.Errorf("init postgresql sequence error, %w", err)
}
}
}

s.snowflakeNodes = make([]*snowflake.Node, 1024)
for i := int64(0); i < 1024; i++ {
n, err := snowflake.NewNode(i)
if err != nil {
return fmt.Errorf("init snowflake node error, %w", err)
}
s.snowflakeNodes[i] = n
}

return nil
}

Expand Down Expand Up @@ -208,7 +312,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
if err != nil {
return ftQuery, stQuery, tableName, err
}
id := r.PrimaryKeyGenerate(int64(tblIdx))
id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
insertNames = append(insertNames, &sqlparser.Ident{Name: "id"})
insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)})
}
Expand Down Expand Up @@ -349,3 +453,7 @@ func getBindValue(value interface{}, args []interface{}) (interface{}, error) {
}
return args[pos-1], nil
}

func pgSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}
Loading

0 comments on commit e15a286

Please sign in to comment.