Skip to content

Commit

Permalink
Split primary_key gen as single file
Browse files Browse the repository at this point in the history
  • Loading branch information
huacnlee committed Jan 28, 2022
1 parent 63aa7f8 commit 7da654f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
29 changes: 29 additions & 0 deletions primary_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package sharding

import "fmt"

const (
// Use Snowflake primary key generator
PKSnowflake = iota
// Use PostgreSQL sequence primary key generator
PKPGSequence
// Use custom primary key generator
PKCustom
)

func (s *Sharding) genSnowflakeKey(index int64) int64 {
return s.snowflakeNodes[index].Generate().Int64()
}

func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 {
var id int64
err := s.DB.Raw("SELECT nextval('" + pgSeqName(tableName) + "')").Scan(&id).Error
if err != nil {
panic(err)
}
return id
}

func pgSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}
11 changes: 11 additions & 0 deletions primary_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sharding

import (
"testing"

"github.com/longbridgeapp/assert"
)

func Test_pgSeqName(t *testing.T) {
assert.Equal(t, "gorm_sharding_users_id_seq", pgSeqName("users"))
}
21 changes: 2 additions & 19 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ import (
"gorm.io/gorm/schema"
)

const (
PKSnowflake = iota // Use Snowflake primary key generator
PKPGSequence // Use PostgreSQL sequence primary key generator
PKCustom // Use custom primary key generator
)

var (
ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
ErrInvalidID = errors.New("invalid id format")
Expand Down Expand Up @@ -109,17 +103,10 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
}

if c.PrimaryKeyGenerator == PKSnowflake {
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.snowflakeNodes[index].Generate().Int64()
}
c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
} else if c.PrimaryKeyGenerator == PKPGSequence {
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
var id int64
err := s.DB.Raw("SELECT nextval('" + pgSeqName(t) + "')").Scan(&id).Error
if err != nil {
panic(err)
}
return id
return s.genPostgreSQLSequenceKey(t, index)
}
} else if c.PrimaryKeyGenerator == PKCustom {
if c.PrimaryKeyGeneratorFn == nil {
Expand Down Expand Up @@ -435,7 +422,3 @@ func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName

return orderBy
}

func pgSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}

0 comments on commit 7da654f

Please sign in to comment.