Skip to content

Commit

Permalink
feat: Refactor CLI code to handle force sync option
Browse files Browse the repository at this point in the history
- Modified `String` method of `arrayForce` to join strings with a comma
- Added import statements for packages "os" and "github.com/rs/zerolog"
- Initialized logger with ConsoleWriter for standard error output
- Added command line flag for "force" option
- Logged message with tables to force sync if length of "forceSync" array is greater than 0
  • Loading branch information
dz0ny committed Aug 12, 2023
1 parent 8ae0c5e commit 6901ecc
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 51 deletions.
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@ Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfe
```
Usage of subsetter:
-dst string
Destination DSN
Destination database DSN
-f float
Fraction of rows to copy (default 0.05)
-force value
Query to copy required tables (users: id = 1)
-src string
Source DSN
Source database DSN
```


Example:
### Example

```pg-subsetter -src postgresql://:@/bigdb -dst postgresql://:@/littledb -f 0.05```
```
pg-subsetter \
-src postgresql://:@/bigdb \
-dst postgresql://:@/littledb \
-f 0.05 \
-force "users: id = 1" \
-force "groups: id = 12"
```

# Installing

Expand Down
27 changes: 27 additions & 0 deletions cli/force.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package main

import (
"fmt"
"strings"
)

type Force struct {
Table string
Where string
}

type arrayForce []Force

func (i *arrayForce) String() string {
return fmt.Sprintf("%v", *i)
}

func (i *arrayForce) Set(value string) error {
q := strings.SplitAfter(strings.TrimSpace(value), ":")

*i = append(*i, Force{
Table: q[0],
Where: q[1],
})
return nil
}
11 changes: 10 additions & 1 deletion cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@ package main

import (
"flag"
"os"

"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)

var src = flag.String("src", "", "Source database DSN")
var dst = flag.String("dst", "", "Destination database DSN")
var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy")
var forceSync arrayForce

func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

flag.Var(&forceSync, "force", "Query to copy required tables (users: id = 1)")
flag.Parse()
log.Info().Msg("Starting")

if *src == "" || *dst == "" {
log.Fatal().Msg("Source and destination DSNs are required")
Expand All @@ -22,4 +27,8 @@ func main() {
log.Fatal().Msg("Fraction must be between 0 and 1")
}

if len(forceSync) > 0 {
log.Info().Msg("Forcing sync for tables: " + forceSync.String())
}

}
1 change: 0 additions & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
golangci-lint
postgresql
process-compose
shellcheck
nixpkgs-fmt
pgweb
];
Expand Down
17 changes: 15 additions & 2 deletions subsetter/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,20 @@ func getTestConnection() *pgx.Conn {
return conn
}

func populateTests(conn *pgx.Conn) {
func getTestConnectionDst() *pgx.Conn {
DATABASE_URL := os.Getenv("DATABASE_URL")
if DATABASE_URL == "" {
DATABASE_URL = "postgres://test_target@localhost:5432/test_target?sslmode=disable"
}

conn, err := pgx.Connect(context.Background(), DATABASE_URL)
if err != nil {
panic(err)
}
return conn
}

func initSchema(conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), `
CREATE TABLE simple (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
Expand Down Expand Up @@ -58,7 +71,7 @@ func populateTestsWithData(conn *pgx.Conn, table string, size int) {
}
}

func clearPopulateTests(conn *pgx.Conn) {
func clearSchema(conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), `
ALTER TABLE relation DROP CONSTRAINT relation_simple_fk;
DROP TABLE simple;
Expand Down
8 changes: 4 additions & 4 deletions subsetter/info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ import (

func TestGetTargetSet(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)

tests := []struct {
name string
fraction float64
tables []Table
want []Table
}{
{"simple", 0.5, []Table{{"simple", 1000, []string{}}}, []Table{{"simple", 31, []string{}}}},
{"simple", 0.5, []Table{{"simple", 10, []string{}}}, []Table{{"simple", 3, []string{}}}},
{"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}},
{"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
18 changes: 14 additions & 4 deletions subsetter/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
type Table struct {
Name string
Rows int
Relations []string
Relations []Relation
}

func GetTables(conn *pgx.Conn) (tables []string, err error) {
Expand All @@ -37,6 +37,11 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) {
// fix for tables with no rows
if table.Rows == -1 {
table.Rows = 0
} else {
table.Relations, err = GetRelations(table.Name, conn)
if err != nil {
return nil, err
}
}
tables = append(tables, table)
}
Expand All @@ -47,16 +52,21 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) {
return
}

func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) {
q := fmt.Sprintf(`copy (SELECT * FROM %s order by random() limit %d) to stdout`, table, limit)
func CopyQueryToString(query string, conn *pgx.Conn) (result string, err error) {
q := fmt.Sprintf(`copy (%s) to stdout`, query)
var buff bytes.Buffer
if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil {
if _, err = conn.PgConn().CopyTo(context.Background(), &buff, q); err != nil {
return
}
result = buff.String()
return
}

func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) {
q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit)
return CopyQueryToString(q, conn)
}

func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) {
q := fmt.Sprintf(`copy %s from stdin`, table)
var buff bytes.Buffer
Expand Down
20 changes: 10 additions & 10 deletions subsetter/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (

func TestGetTables(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)
tests := []struct {
name string
conn *pgx.Conn
Expand All @@ -32,16 +32,16 @@ func TestGetTables(t *testing.T) {

func TestGetTablesWithRows(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)
tests := []struct {
name string
conn *pgx.Conn
wantTables []Table
wantErr bool
}{
{"With tables", conn, []Table{{"simple", 0, []string{}}, {"relation", 0, []string{}}}, false},
{"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -62,9 +62,9 @@ func TestGetTablesWithRows(t *testing.T) {

func TestCopyTableToString(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)
populateTestsWithData(conn, "simple", 10)

tests := []struct {
Expand All @@ -83,7 +83,7 @@ func TestCopyTableToString(t *testing.T) {
t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if strings.Contains(gotResult, "test") == tt.wantResult {
if strings.Contains(gotResult, "test") != tt.wantResult {
t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult)
}
})
Expand All @@ -92,9 +92,9 @@ func TestCopyTableToString(t *testing.T) {

func TestCopyStringToTable(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)

tests := []struct {
name string
Expand Down
15 changes: 12 additions & 3 deletions subsetter/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,28 @@ package subsetter

import (
"context"
"fmt"

"github.com/jackc/pgx/v5"
)

type Relation struct {
Table string
Column string
Main string
MainColumn string
Related string
RelatedColumn string
}

func (r *Relation) Query() string {
return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.Related, r.RelatedColumn, r.MainColumn, r.Main)
}

// GetRelations returns a list of tables that have a foreign key for particular table.
func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error) {

q := `SELECT
ccu.table_name,
ccu.column_name,
kcu.table_name,
kcu.column_name
FROM
Expand All @@ -28,7 +37,7 @@ func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error
rows, err := conn.Query(context.Background(), q, table)
for rows.Next() {
var rel Relation
if err := rows.Scan(&rel.Table, &rel.Column); err == nil {
if err := rows.Scan(&rel.Main, &rel.MainColumn, &rel.Related, &rel.RelatedColumn); err == nil {
relations = append(relations, rel)
}
}
Expand Down
23 changes: 20 additions & 3 deletions subsetter/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (

func TestGetRelations(t *testing.T) {
conn := getTestConnection()
populateTests(conn)
initSchema(conn)
defer conn.Close(context.Background())
defer clearPopulateTests(conn)
defer clearSchema(conn)
tests := []struct {
name string
table string
conn *pgx.Conn
wantRelations []Relation
}{
{"With relation", "simple", conn, []Relation{{"relation", "simple_id"}}},
{"With relation", "simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -29,3 +29,20 @@ func TestGetRelations(t *testing.T) {
})
}
}

func TestRelation_Query(t *testing.T) {
tests := []struct {
name string
r Relation
want string
}{
{"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.r.Query(); got != tt.want {
t.Errorf("Relation.Query() = %v, want %v", got, tt.want)
}
})
}
}
Loading

0 comments on commit 6901ecc

Please sign in to comment.