Skip to content

Commit

Permalink
Refactor adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
minhduc140583 committed Jul 14, 2024
1 parent 674e972 commit edaa34c
Show file tree
Hide file tree
Showing 11 changed files with 1,147 additions and 17 deletions.
6 changes: 2 additions & 4 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type Adapter[T any, K any] struct {
*Writer[*T]
Map map[string]int
Fields string
Keys []string
IdMap bool
}

Expand All @@ -41,11 +40,10 @@ func NewSqlAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName string
return nil, errors.New("T must be a struct")
}

_, primaryKeys := q.FindPrimaryKeys(modelType)
var k K
kType := reflect.TypeOf(k)
idMap := false
if len(primaryKeys) > 1 {
if len(adapter.Keys) > 1 {
if kType.Kind() == reflect.Map {
idMap = true
} else if kType.Kind() != reflect.Struct {
Expand All @@ -58,7 +56,7 @@ func NewSqlAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName string
return nil, err
}
fields := q.BuildFieldsBySchema(adapter.Schema)
return &Adapter[T, K]{adapter, fieldsIndex, fields, primaryKeys, idMap}, nil
return &Adapter[T, K]{adapter, fieldsIndex, fields, idMap}, nil
}
func (a *Adapter[T, K]) All(ctx context.Context) ([]T, error) {
var objs []T
Expand Down
60 changes: 53 additions & 7 deletions adapter/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Writer[T any] struct {
DB *sql.DB
Table string
Schema *q.Schema
Keys []string
JsonColumnMap map[string]string
BuildParam func(int) string
Driver string
Expand Down Expand Up @@ -63,10 +64,14 @@ func NewSqlWriterWithVersionAndArray[T any](db *sql.DB, tableName string, versio
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
_, primaryKeys := q.FindPrimaryKeys(modelType)
if len(primaryKeys) == 0 {
return nil, fmt.Errorf("require primary key for table '%s'", tableName)
}
schema := q.CreateSchema(modelType)
jsonColumnMapT := q.MakeJsonColumnMap(modelType)
jsonColumnMap := q.GetWritableColumns(schema.Fields, jsonColumnMapT)
adapter := &Writer[T]{DB: db, Table: tableName, Schema: schema, JsonColumnMap: jsonColumnMap, BuildParam: buildParam, Driver: drivr, BoolSupport: boolSupport, ToArray: toArray, TxKey: "tx", versionIndex: -1}
adapter := &Writer[T]{DB: db, Table: tableName, Schema: schema, Keys: primaryKeys, JsonColumnMap: jsonColumnMap, BuildParam: buildParam, Driver: drivr, BoolSupport: boolSupport, ToArray: toArray, TxKey: "tx", versionIndex: -1}
if len(versionField) > 0 {
index := q.FindFieldIndex(modelType, versionField)
if index >= 0 {
Expand All @@ -87,7 +92,7 @@ func (a *Writer[T]) Create(ctx context.Context, model T) (int64, error) {
query, args := q.BuildToInsertWithVersion(a.Table, model, a.versionIndex, a.BuildParam, a.BoolSupport, a.ToArray, a.Schema)
res, err := tx.ExecContext(ctx, query, args...)
if err != nil {
return -1, err
return q.HandleDuplicate(a.DB, err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
Expand All @@ -113,11 +118,29 @@ func (a *Writer[T]) Update(ctx context.Context, model T) (int64, error) {
if err != nil {
return rowsAffected, err
}
if rowsAffected > 0 && a.versionIndex >= 0 {
vo := reflect.ValueOf(model)
if vo.Kind() == reflect.Ptr {
vo = reflect.Indirect(vo)
vo := reflect.ValueOf(model)
if vo.Kind() == reflect.Ptr {
vo = reflect.Indirect(vo)
}
if rowsAffected < 1 {
var values []interface{}
query1 := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
le := len(a.Keys)
var where []string
for i := 0; i < le; i++ {
where = append(where, fmt.Sprintf("%s = %s", a.Schema.Keys[i].Column), a.BuildParam(i+1))
}
query2 := query1 + " where " + strings.Join(where, " and ")
rows, er2 := tx.QueryContext(ctx, query2, values...)
if er2 != nil {
return -1, err
}
defer rows.Close()
for rows.Next() {
return -1, nil
}
return 0, nil
} else if a.versionIndex >= 0 {
currentVersion := vo.Field(a.versionIndex).Interface()
increaseVersion(vo, a.versionIndex, currentVersion)
}
Expand Down Expand Up @@ -159,7 +182,30 @@ func (a *Writer[T]) Patch(ctx context.Context, model map[string]interface{}) (in
if err != nil {
return rowsAffected, err
}
if rowsAffected > 0 && a.versionIndex >= 0 {
if rowsAffected < 1 {
var query2 string
var values []interface{}
query1 := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
if len(a.Keys) == 1 {
query2, values = q.BuildFindByIdWithDB(a.DB, query1, model[a.Keys[0]], a.JsonColumnMap, a.Keys, a.BuildParam)
} else {
im := make(map[string]interface{})
le := len(a.Keys)
for i := 0; i < le; i++ {
im[a.Keys[i]] = model[a.Keys[i]]
}
query2, values = q.BuildFindByIdWithDB(a.DB, query1, im, a.JsonColumnMap, a.Keys, a.BuildParam)
}
rows, er2 := tx.QueryContext(ctx, query2, values...)
if er2 != nil {
return -1, err
}
defer rows.Close()
for rows.Next() {
return -1, nil
}
return 0, nil
} else if a.versionIndex >= 0 {
currentVersion, vok := model[a.versionJson]
if !vok {
return -1, fmt.Errorf("%s must be in model for patch", a.versionJson)
Expand Down
134 changes: 134 additions & 0 deletions dao/dao.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package dao

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"

q "github.com/core-go/sql"
)

type Dao[T any, K any] struct {
*Writer[*T]
Map map[string]int
Fields string
IdMap bool
}

func NewDao[T any, K any](db *sql.DB, tableName string, opts ...func(int) string) (*Dao[T, K], error) {
return NewSqlDaoWithVersionAndArray[T, K](db, tableName, "", nil, opts...)
}
func NewDaoWithVersion[T any, K any](db *sql.DB, tableName string, versionField string, opts ...func(int) string) (*Dao[T, K], error) {
return NewSqlDaoWithVersionAndArray[T, K](db, tableName, versionField, nil, opts...)
}
func NewSqlDaoWithVersionAndArray[T any, K any](db *sql.DB, tableName string, versionField string, toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}, opts ...func(int) string) (*Dao[T, K], error) {
adapter, err := NewSqlWriterWithVersionAndArray[*T](db, tableName, versionField, toArray, opts...)
if err != nil {
return nil, err
}

var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() != reflect.Struct {
return nil, errors.New("T must be a struct")
}

var k K
kType := reflect.TypeOf(k)
idMap := false
if len(adapter.Keys) > 1 {
if kType.Kind() == reflect.Map {
idMap = true
} else if kType.Kind() != reflect.Struct {
return nil, errors.New("for composite keys, K must be a struct or a map")
}
}

fieldsIndex, err := q.GetColumnIndexes(modelType)
if err != nil {
return nil, err
}
fields := q.BuildFieldsBySchema(adapter.Schema)
return &Dao[T, K]{adapter, fieldsIndex, fields, idMap}, nil
}
func (a *Dao[T, K]) All(ctx context.Context) ([]T, error) {
var objs []T
query := fmt.Sprintf("select %s from %s", a.Fields, a.Table)
tx := q.GetExec(ctx, a.DB, a.TxKey)
err := q.Query(ctx, tx, a.Map, &objs, query)
return objs, err
}
func toMap(obj interface{}) (map[string]interface{}, error) {
b, err := json.Marshal(obj)
if err != nil {
return nil, err
}
im := make(map[string]interface{})
er2 := json.Unmarshal(b, &im)
return im, er2
}
func (a *Dao[T, K]) getId(k K) (interface{}, error) {
if len(a.Keys) >= 2 && !a.IdMap {
ri, err := toMap(k)
return ri, err
} else {
return k, nil
}
}
func (a *Dao[T, K]) Load(ctx context.Context, id K) (*T, error) {
ip, er0 := a.getId(id)
if er0 != nil {
return nil, er0
}
var objs []T
query := fmt.Sprintf("select %s from %s ", a.Fields, a.Table)
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
tx := q.GetExec(ctx, a.DB, a.TxKey)
err := q.Query(ctx, tx, a.Map, &objs, query1, args...)
if err != nil {
return nil, err
}
if len(objs) > 0 {
return &objs[0], nil
}
return nil, nil
}
func (a *Dao[T, K]) Exist(ctx context.Context, id K) (bool, error) {
ip, er0 := a.getId(id)
if er0 != nil {
return false, er0
}
query := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
tx := q.GetExec(ctx, a.DB, a.TxKey)
rows, err := tx.QueryContext(ctx, query1, args...)
if err != nil {
return false, err
}
defer rows.Close()
for rows.Next() {
return true, nil
}
return false, nil
}
func (a *Dao[T, K]) Delete(ctx context.Context, id K) (int64, error) {
ip, er0 := a.getId(id)
if er0 != nil {
return -1, er0
}
query := fmt.Sprintf("delete from %s ", a.Table)
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
tx := q.GetExec(ctx, a.DB, a.TxKey)
res, err := tx.ExecContext(ctx, query1, args...)
if err != nil {
return -1, err
}
return res.RowsAffected()
}
65 changes: 65 additions & 0 deletions dao/search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package dao

import (
"context"
"database/sql"
"database/sql/driver"
"reflect"

q "github.com/core-go/sql"
)

type SearchDao[T any, K any, F any] struct {
*Dao[T, K]
BuildQuery func(F) (string, []interface{})
Mp func(*T)
Map map[string]int
ToArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}
}

func NewSearchDao[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), options ...func(*T)) (*SearchDao[T, K, F], error) {
return NewSearchDaoWithArray[T, K, F](db, table, buildQuery, nil, "", nil, options...)
}
func NewSearchDaoWithVersion[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), versionField string, options ...func(*T)) (*SearchDao[T, K, F], error) {
return NewSearchDaoWithArray[T, K, F](db, table, buildQuery, nil, versionField, nil, options...)
}
func NewSearchDaoWithArray[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}, versionField string, buildParam func(int) string, opts ...func(*T)) (*SearchDao[T, K, F], error) {
daObj, err := NewSqlDaoWithVersionAndArray[T, K](db, table, versionField, toArray, buildParam)
if err != nil {
return nil, err
}
var mp func(*T)
if len(opts) >= 1 {
mp = opts[0]
}
var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
fieldsIndex, err := q.GetColumnIndexes(modelType)
if err != nil {
return nil, err
}
builder := &SearchDao[T, K, F]{Dao: daObj, Map: fieldsIndex, BuildQuery: buildQuery, Mp: mp, ToArray: toArray}
return builder, nil
}

func (b *SearchDao[T, K, F]) Search(ctx context.Context, filter F, limit int64, offset int64) ([]T, int64, error) {
var objs []T
query, args := b.BuildQuery(filter)
total, er2 := q.BuildFromQuery(ctx, b.DB, b.Map, &objs, query, args, limit, offset, b.ToArray)
if b.Mp != nil {
l := len(objs)
for i := 0; i < l; i++ {
b.Mp(&objs[i])
}
}
return objs, total, er2
}
59 changes: 59 additions & 0 deletions dao/search_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package dao

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"

q "github.com/core-go/sql"
)

type SearchBuilder[T any, F any] struct {
Database *sql.DB
BuildQuery func(F) (string, []interface{})
fieldsIndex map[string]int
Map func(*T)
ToArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}
}

func NewSearchBuilder[T any, F any](db *sql.DB, buildQuery func(F) (string, []interface{}), opts ...func(*T)) (*SearchBuilder[T, F], error) {
return NewSearchBuilderWithArray[T, F](db, buildQuery, nil, opts...)
}
func NewSearchBuilderWithArray[T any, F any](db *sql.DB, buildQuery func(F) (string, []interface{}), toArray func(interface{}) interface {
driver.Valuer
sql.Scanner
}, opts ...func(*T)) (*SearchBuilder[T, F], error) {
var t T
modelType := reflect.TypeOf(t)
if modelType.Kind() != reflect.Struct {
return nil, errors.New("T must be a struct")
}
var mp func(*T)
if len(opts) >= 1 {
mp = opts[0]
}
fieldsIndex, err := q.GetColumnIndexes(modelType)
if err != nil {
return nil, err
}
builder := &SearchBuilder[T, F]{Database: db, fieldsIndex: fieldsIndex, BuildQuery: buildQuery, Map: mp, ToArray: toArray}
return builder, nil
}

func (b *SearchBuilder[T, F]) Search(ctx context.Context, m F, limit int64, offset int64) ([]T, int64, error) {
sql, params := b.BuildQuery(m)
var objs []T
total, er2 := q.BuildFromQuery(ctx, b.Database, b.fieldsIndex, &objs, sql, params, limit, offset, b.ToArray)
if b.Map != nil {
l := len(objs)
for i := 0; i < l; i++ {
b.Map(&objs[i])
}
}
return objs, total, er2
}
Loading

0 comments on commit edaa34c

Please sign in to comment.