Skip to content

Commit

Permalink
Fix xo#321: postgres and pgx support schema name in table destination…
Browse files Browse the repository at this point in the history
… for /copy
  • Loading branch information
murfffi committed Apr 28, 2024
1 parent 76c8780 commit a8e6dd3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
18 changes: 18 additions & 0 deletions drivers/drivers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,15 @@ func TestCopy(t *testing.T) {
src: "select * from staff",
dest: "staff_copy",
},
{
dbName: "pgsql",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
},
src: "select * from staff",
dest: "public.staff_copy",
},
{
dbName: "pgx",
setupQueries: []setupQuery{
Expand All @@ -431,6 +440,15 @@ func TestCopy(t *testing.T) {
src: "select * from staff",
dest: "staff_copy",
},
{
dbName: "pgx",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
},
src: "select * from staff",
dest: "public.staff_copy",
},
{
dbName: "mysql",
setupQueries: []setupQuery{
Expand Down
8 changes: 6 additions & 2 deletions drivers/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func init() {
var n int64
err = conn.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn).Conn()
n, err = conn.CopyFrom(ctx, pgx.Identifier{table}, columns, crows)
n, err = conn.CopyFrom(ctx, pgx.Identifier(strings.SplitN(table, ".", 2)), columns, crows)
return err
})
return n, err
Expand All @@ -141,7 +141,11 @@ func (r *copyRows) Next() bool {

func (r *copyRows) Values() ([]interface{}, error) {
err := r.rows.Scan(r.values...)
return r.values, err
actuals := make([]interface{}, len(r.values))
for i, v := range r.values {
actuals[i] = *(v.(*interface{}))
}
return actuals, err
}

func (r *copyRows) Err() error {
Expand Down
6 changes: 5 additions & 1 deletion drivers/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ func init() {
if err != nil {
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
}
query = pq.CopyIn(table, columns...)
if schemaSep := strings.Index(table, "."); schemaSep >= 0 {
query = pq.CopyInSchema(table[:schemaSep], table[schemaSep+1:], columns...)
} else {
query = pq.CopyIn(table, columns...)
}
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
Expand Down

0 comments on commit a8e6dd3

Please sign in to comment.