diff --git a/drivers/drivers_test.go b/drivers/drivers_test.go index 1ed8cd5e7c3..3649b002998 100644 --- a/drivers/drivers_test.go +++ b/drivers/drivers_test.go @@ -422,6 +422,15 @@ func TestCopy(t *testing.T) { src: "select * from staff", dest: "staff_copy", }, + { + dbName: "pgsql", + setupQueries: []setupQuery{ + {query: "DROP TABLE staff_copy"}, + {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, + }, + src: "select * from staff", + dest: "public.staff_copy", + }, { dbName: "pgx", setupQueries: []setupQuery{ @@ -431,6 +440,15 @@ func TestCopy(t *testing.T) { src: "select * from staff", dest: "staff_copy", }, + { + dbName: "pgx", + setupQueries: []setupQuery{ + {query: "DROP TABLE staff_copy"}, + {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true}, + }, + src: "select * from staff", + dest: "public.staff_copy", + }, { dbName: "mysql", setupQueries: []setupQuery{ diff --git a/drivers/pgx/pgx.go b/drivers/pgx/pgx.go index 0af331e8f45..b847124b88d 100644 --- a/drivers/pgx/pgx.go +++ b/drivers/pgx/pgx.go @@ -122,7 +122,7 @@ func init() { var n int64 err = conn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() - n, err = conn.CopyFrom(ctx, pgx.Identifier{table}, columns, crows) + n, err = conn.CopyFrom(ctx, pgx.Identifier(strings.SplitN(table, ".", 2)), columns, crows) return err }) return n, err @@ -141,7 +141,11 @@ func (r *copyRows) Next() bool { func (r *copyRows) Values() ([]interface{}, error) { err := r.rows.Scan(r.values...) - return r.values, err + actuals := make([]interface{}, len(r.values)) + for i, v := range r.values { + actuals[i] = *(v.(*interface{})) + } + return actuals, err } func (r *copyRows) Err() error { diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 31118b23b53..335fed2fd6c 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -137,7 +137,11 @@ func init() { if err != nil { return 0, fmt.Errorf("failed to fetch target table columns: %w", err) } - query = pq.CopyIn(table, columns...) + if schemaSep := strings.Index(table, "."); schemaSep >= 0 { + query = pq.CopyInSchema(table[:schemaSep], table[schemaSep+1:], columns...) + } else { + query = pq.CopyIn(table, columns...) + } } tx, err := db.BeginTx(ctx, nil) if err != nil {