Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dereference fix in common copy implementation #461

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 1 addition & 76 deletions drivers/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"

Expand Down Expand Up @@ -38,79 +35,7 @@ func init() {
}
return false
},
Copy: CopyWithInsert,
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
NewMetadataReader: NewMetadataReader,
})
}

// CopyWithInsert copies every row of rows into table by executing a prepared
// INSERT statement once per source row inside a single transaction.
//
// table is interpreted in one of three ways:
//   - if it starts with "insert into" (case-insensitive) it is used verbatim
//     as the insert statement;
//   - if it contains a '(' it is treated as "name(col, ...)" and an
//     "INSERT INTO ... VALUES (...)" statement is built around it;
//   - otherwise it is treated as a bare table name and the target column list
//     is discovered by running "SELECT * FROM <table> WHERE 1=0".
//
// Generated statements contain one "?" placeholder per source column. Note
// that table is concatenated into SQL text as is, so it must come from a
// trusted caller.
//
// It returns the sum of the RowsAffected values reported for the executed
// inserts. On any failure before the commit the transaction is rolled back and
// the count accumulated so far is returned together with the error.
func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
	columns, err := rows.Columns()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
	}
	clen := len(columns)
	query := table
	if !strings.HasPrefix(strings.ToLower(query), "insert into") {
		if !strings.ContainsRune(table, '(') {
			colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
			if err != nil {
				return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
			}
			targetColumns, err := colRows.Columns()
			// The result set is empty by construction; a close error carries
			// no information that matters for the copy.
			_ = colRows.Close()
			if err != nil {
				return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
			}
			table += "(" + strings.Join(targetColumns, ", ") + ")"
		}
		// Built so that zero source columns yields an empty list instead of
		// panicking on a negative repeat count.
		placeholders := strings.TrimSuffix(strings.Repeat("?, ", clen), ", ")
		query = "INSERT INTO " + table + " VALUES (" + placeholders + ")"
	}
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Release the transaction (and its connection) on every early return.
	// After a successful Commit this is a no-op that reports sql.ErrTxDone,
	// which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()
	stmt, err := tx.PrepareContext(ctx, query)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare insert query: %w", err)
	}
	defer stmt.Close()
	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source column types: %w", err)
	}
	// values holds the pointers handed to Scan, valueRefs the reflect handles
	// to the same storage, and actuals the dereferenced values passed to Exec.
	values := make([]interface{}, clen)
	valueRefs := make([]reflect.Value, clen)
	actuals := make([]interface{}, clen)
	for i := range columnTypes {
		valueRefs[i] = reflect.New(columnTypes[i].ScanType())
		values[i] = valueRefs[i].Interface()
	}
	var n int64
	for rows.Next() {
		if err := rows.Scan(values...); err != nil {
			return n, fmt.Errorf("failed to scan row: %w", err)
		}
		// We can't use values... in Exec() below, because, in some cases, the
		// clickhouse driver doesn't accept a pointer to an argument instead of
		// the argument itself.
		for i := range values {
			actuals[i] = valueRefs[i].Elem().Interface()
		}
		res, err := stmt.ExecContext(ctx, actuals...)
		if err != nil {
			return n, fmt.Errorf("failed to exec insert: %w", err)
		}
		rn, err := res.RowsAffected()
		if err != nil {
			return n, fmt.Errorf("failed to check rows affected: %w", err)
		}
		n += rn
	}
	if err := tx.Commit(); err != nil {
		return n, fmt.Errorf("failed to commit transaction: %w", err)
	}
	return n, rows.Err()
}
20 changes: 12 additions & 8 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,12 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
if !strings.HasPrefix(strings.ToLower(query), "insert into") {
leftParen := strings.IndexRune(table, '(')
if leftParen == -1 {
colStmt, err := db.PrepareContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
if err != nil {
return 0, fmt.Errorf("failed to prepare query to determine target table columns: %w", err)
}
defer colStmt.Close()
colRows, err := colStmt.QueryContext(ctx)
colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
if err != nil {
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
}
columns, err := colRows.Columns()
_ = colRows.Close()
if err != nil {
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
}
Expand All @@ -576,16 +572,24 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
return 0, fmt.Errorf("failed to fetch source column types: %w", err)
}
values := make([]interface{}, clen)
valueRefs := make([]reflect.Value, clen)
actuals := make([]interface{}, clen)
for i := 0; i < len(columnTypes); i++ {
values[i] = reflect.New(columnTypes[i].ScanType()).Interface()
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
values[i] = valueRefs[i].Interface()
}
var n int64
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
return n, fmt.Errorf("failed to scan row: %w", err)
}
res, err := stmt.ExecContext(ctx, values...)
//We can't use values... in Exec() below, because some drivers
//don't accept pointer to an argument instead of the arg itself.
for i := range values {
actuals[i] = valueRefs[i].Elem().Interface()
}
res, err := stmt.ExecContext(ctx, actuals...)
if err != nil {
return n, fmt.Errorf("failed to exec insert: %w", err)
}
Expand Down
117 changes: 77 additions & 40 deletions drivers/drivers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ var (
DSN: "trino://test@localhost:%s/tpch/sf1",
DockerPort: "8080/tcp",
},
"csvq": {
// go test sets working directory to current package regardless of initial working directory
DSN: "csvq://./testdata/csvq",
},
}
cleanup bool
)
Expand Down Expand Up @@ -144,30 +148,21 @@ func TestMain(m *testing.M) {
}

for dbName, db := range dbs {
var ok bool
db.Resource, ok = pool.ContainerByName(db.RunOptions.Name)
if !ok {
buildOpts := &dt.BuildOptions{
ContextDir: "./testdata/docker",
BuildArgs: db.BuildArgs,
}
db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions)
if err != nil {
log.Fatalf("Could not start %s: %s", dbName, err)
}
}

hostPort := db.Resource.GetPort(db.DockerPort)
db.URL, err = dburl.Parse(fmt.Sprintf(db.DSN, hostPort))
dsn, hostPort := getConnInfo(dbName, db, pool)
db.URL, err = dburl.Parse(dsn)
if err != nil {
log.Fatalf("Failed to parse %s URL %s: %v", dbName, db.DSN, err)
}

if len(db.Exec) != 0 {
readyDSN := db.ReadyDSN
if db.ReadyDSN == "" {
db.ReadyDSN = db.DSN
readyDSN = db.DSN
}
if hostPort != "" {
readyDSN = fmt.Sprintf(db.ReadyDSN, hostPort)
}
readyURL, err := dburl.Parse(fmt.Sprintf(db.ReadyDSN, hostPort))
readyURL, err := dburl.Parse(readyDSN)
if err != nil {
log.Fatalf("Failed to parse %s ready URL %s: %v", dbName, db.ReadyDSN, err)
}
Expand Down Expand Up @@ -205,15 +200,46 @@ func TestMain(m *testing.M) {
// You can't defer this because os.Exit doesn't care for defer
if cleanup {
for _, db := range dbs {
if err := pool.Purge(db.Resource); err != nil {
log.Fatal("Could not purge resource: ", err)
if db.Resource != nil {
if err := pool.Purge(db.Resource); err != nil {
log.Fatal("Could not purge resource: ", err)
}
}
}
}

os.Exit(code)
}

// getConnInfo returns the DSN to connect to the test database db, plus the
// host port that was substituted into it.
//
// A database without RunOptions is not container-backed: its DSN is returned
// unchanged and the host port is empty. Otherwise the container named by
// RunOptions.Name is looked up in pool; a container that exists but is not
// running is closed and treated as missing, and a missing container is built
// from ./testdata/docker and started. Any failure along the way terminates the
// process via log.Fatalf. db.Resource is updated to the container in use.
func getConnInfo(dbName string, db *Database, pool *dt.Pool) (string, string) {
	if db.RunOptions == nil {
		return db.DSN, ""
	}

	var found bool
	db.Resource, found = pool.ContainerByName(db.RunOptions.Name)

	// A leftover container that is no longer running cannot be reused.
	if found && !db.Resource.Container.State.Running {
		if closeErr := db.Resource.Close(); closeErr != nil {
			log.Fatalf("Failed to clean up stale container %s: %s", dbName, closeErr)
		}
		found = false
	}

	if !found {
		var runErr error
		db.Resource, runErr = pool.BuildAndRunWithBuildOptions(
			&dt.BuildOptions{
				ContextDir: "./testdata/docker",
				BuildArgs:  db.BuildArgs,
			},
			db.RunOptions,
		)
		if runErr != nil {
			log.Fatalf("Failed to start %s: %s", dbName, runErr)
		}
	}

	port := db.Resource.GetPort(db.DockerPort)
	dsn := fmt.Sprintf(db.DSN, port)
	return dsn, port
}

func TestWriter(t *testing.T) {
type testFunc struct {
label string
Expand Down Expand Up @@ -467,37 +493,48 @@ func TestCopy(t *testing.T) {
src: "select first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy(first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
},
{
dbName: "csvq",
setupQueries: []setupQuery{
{query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true},
},
src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy",
},
}
for _, test := range testCases {
db, ok := dbs[test.dbName]
if !ok {
continue
}

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB
t.Run(test.dbName, func(t *testing.T) {

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB

for _, q := range test.setupQueries {
_, err := db.DB.Exec(q.query)
if q.check && err != nil {
log.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
for _, q := range test.setupQueries {
_, err := db.DB.Exec(q.query)
if q.check && err != nil {
t.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
}
}
rows, err := pg.DB.Query(test.src)
if err != nil {
t.Fatalf("Could not get rows to copy: %v", err)
}
}
rows, err := pg.DB.Query(test.src)
if err != nil {
log.Fatalf("Could not get rows to copy: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var rlen int64 = 1
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
if err != nil {
log.Fatalf("Could not copy: %v", err)
}
if n != rlen {
log.Fatalf("Expected to copy %d rows but got %d", rlen, n)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var rlen int64 = 1
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
if err != nil {
t.Fatalf("Could not copy: %v", err)
}
if n != rlen {
t.Fatalf("Expected to copy %d rows but got %d", rlen, n)
}
})
}
}

Expand Down
1 change: 1 addition & 0 deletions drivers/testdata/csvq/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*_copy
2 changes: 2 additions & 0 deletions drivers/testdata/csvq/staff.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
first_name,last_name,address_id,email,store_id,active,username,password,last_update
John,Doe,1,[email protected],1,true,jdoe,abc,2024-05-10T08:12:05.46875Z
Loading