From 6901ecc6dca9d923c2fc2eb5e80fd70991143220 Mon Sep 17 00:00:00 2001 From: Janez Troha Date: Fri, 11 Aug 2023 21:31:10 +0200 Subject: [PATCH] feat: Refactor CLI code to handle force sync option - Modified `String` method of `arrayForce` to join strings with a comma - Added import statements for packages "os" and "github.com/rs/zerolog" - Initialized logger with ConsoleWriter for standard error output - Added command line flag for "force" option - Logged message with tables to force sync if length of "forceSync" array is greater than 0 --- README.md | 17 +++++++-- cli/force.go | 27 +++++++++++++ cli/main.go | 11 +++++- flake.nix | 1 - subsetter/db_test.go | 17 ++++++++- subsetter/info_test.go | 8 ++-- subsetter/query.go | 18 +++++++-- subsetter/query_test.go | 20 +++++----- subsetter/relations.go | 15 ++++++-- subsetter/relations_test.go | 23 +++++++++-- subsetter/sync.go | 76 +++++++++++++++++++++++++++---------- subsetter/sync_test.go | 49 ++++++++++++++++++++++++ 12 files changed, 231 insertions(+), 51 deletions(-) create mode 100644 cli/force.go create mode 100644 subsetter/sync_test.go diff --git a/README.md b/README.md index 8751565..496951b 100644 --- a/README.md +++ b/README.md @@ -23,17 +23,26 @@ Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfe ``` Usage of subsetter: -dst string - Destination DSN + Destination database DSN -f float Fraction of rows to copy (default 0.05) + -force value + Query to copy required tables (users: id = 1) -src string - Source DSN + Source database DSN ``` -Example: +### Example -```pg-subsetter -src postgresql://:@/bigdb -dst postgresql://:@/littledb -f 0.05``` +``` +pg-subsetter \ + -src postgresql://:@/bigdb \ + -dst postgresql://:@/littledb \ + -f 0.05 \ + -force "users: id = 1" \ + -force "groups: id = 12" +``` # Installing diff --git a/cli/force.go b/cli/force.go new file mode 100644 index 0000000..d1ee7fc --- /dev/null +++ b/cli/force.go @@ -0,0 +1,27 @@ +package main + +import ( + "fmt" + "strings" +) + +type Force struct { + Table string + Where string +} + +type arrayForce []Force + +func (i *arrayForce) String() string { + return fmt.Sprintf("%v", *i) +} + +func (i *arrayForce) Set(value string) error { + q := strings.SplitAfter(strings.TrimSpace(value), ":") + + *i = append(*i, Force{ + Table: q[0], + Where: q[1], + }) + return nil +} diff --git a/cli/main.go b/cli/main.go index dd07c59..08fdcc3 100644 --- a/cli/main.go +++ b/cli/main.go @@ -2,17 +2,22 @@ package main import ( "flag" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy") +var forceSync arrayForce func main() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + + flag.Var(&forceSync, "force", "Query to copy required tables (users: id = 1)") flag.Parse() - log.Info().Msg("Starting") if *src == "" || *dst == "" { log.Fatal().Msg("Source and destination DSNs are required") @@ -22,4 +27,8 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } + if len(forceSync) > 0 { + log.Info().Msg("Forcing sync for tables: " + forceSync.String()) + } + } diff --git a/flake.nix b/flake.nix index 29396f0..6367fbc 100644 --- a/flake.nix +++ b/flake.nix @@ -37,7 +37,6 @@ golangci-lint postgresql process-compose - shellcheck nixpkgs-fmt pgweb ]; diff --git a/subsetter/db_test.go b/subsetter/db_test.go index 5c6ecea..c9d6b31 100644 --- a/subsetter/db_test.go +++ b/subsetter/db_test.go @@ -21,7 +21,20 @@ func getTestConnection() *pgx.Conn { return conn } -func populateTests(conn *pgx.Conn) { +func getTestConnectionDst() *pgx.Conn { + DATABASE_URL := os.Getenv("DATABASE_URL") + if DATABASE_URL == "" { + DATABASE_URL = "postgres://test_target@localhost:5432/test_target?sslmode=disable" + } + + conn, err := pgx.Connect(context.Background(), DATABASE_URL) + if err != nil { + panic(err) + } + return conn +} + +func initSchema(conn *pgx.Conn) { _, err := conn.Exec(context.Background(), ` CREATE TABLE simple ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -58,7 +71,7 @@ func populateTestsWithData(conn *pgx.Conn, table string, size int) { } } -func clearPopulateTests(conn *pgx.Conn) { +func clearSchema(conn *pgx.Conn) { _, err := conn.Exec(context.Background(), ` ALTER TABLE relation DROP CONSTRAINT relation_simple_fk; DROP TABLE simple; diff --git a/subsetter/info_test.go b/subsetter/info_test.go index 28d843d..a9fb68e 100644 --- a/subsetter/info_test.go +++ b/subsetter/info_test.go @@ -7,9 +7,9 @@ import ( func TestGetTargetSet(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string @@ -17,8 +17,8 @@ func TestGetTargetSet(t *testing.T) { tables []Table want []Table }{ - {"simple", 0.5, []Table{{"simple", 1000, []string{}}}, []Table{{"simple", 31, []string{}}}}, - {"simple", 0.5, []Table{{"simple", 10, []string{}}}, []Table{{"simple", 3, []string{}}}}, + {"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}}, + {"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/query.go b/subsetter/query.go index 4bd5856..52cd623 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -11,7 +11,7 @@ import ( type Table struct { Name string Rows int - Relations []string + Relations []Relation } func GetTables(conn *pgx.Conn) (tables []string, err error) { @@ -37,6 +37,11 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { // fix for tables with no rows if table.Rows == -1 { table.Rows = 0 + } else { + table.Relations, err = GetRelations(table.Name, conn) + if err != nil { + return nil, err + } } tables = append(tables, table) } @@ -47,16 +52,21 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { return } -func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { - q := fmt.Sprintf(`copy (SELECT * FROM %s order by random() limit %d) to stdout`, table, limit) +func CopyQueryToString(query string, conn *pgx.Conn) (result string, err error) { + q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer - if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + if _, err = conn.PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() return } +func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit) + return CopyQueryToString(q, conn) +} + func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) { q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer diff --git a/subsetter/query_test.go b/subsetter/query_test.go index 0bb4f58..12c98c8 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -11,9 +11,9 @@ import ( func TestGetTables(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string conn *pgx.Conn @@ -32,16 +32,16 @@ func TestGetTables(t *testing.T) { func TestGetTablesWithRows(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string conn *pgx.Conn wantTables []Table wantErr bool }{ - {"With tables", conn, []Table{{"simple", 0, []string{}}, {"relation", 0, []string{}}}, false}, + {"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -62,9 +62,9 @@ func TestGetTablesWithRows(t *testing.T) { func TestCopyTableToString(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) populateTestsWithData(conn, "simple", 10) tests := []struct { @@ -83,7 +83,7 @@ func TestCopyTableToString(t *testing.T) { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return } - if strings.Contains(gotResult, "test") == tt.wantResult { + if strings.Contains(gotResult, "test") != tt.wantResult { t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult) } }) @@ -92,9 +92,9 @@ func TestCopyTableToString(t *testing.T) { func TestCopyStringToTable(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string diff --git a/subsetter/relations.go b/subsetter/relations.go index ea96639..5420e5e 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -2,19 +2,28 @@ package subsetter import ( "context" + "fmt" "github.com/jackc/pgx/v5" ) type Relation struct { - Table string - Column string + Main string + MainColumn string + Related string + RelatedColumn string +} + +func (r *Relation) Query() string { + return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.Related, r.RelatedColumn, r.MainColumn, r.Main) } // GetRelations returns a list of tables that have a foreign key for particular table. func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error) { q := `SELECT + ccu.table_name, + ccu.column_name, kcu.table_name, kcu.column_name FROM @@ -28,7 +37,7 @@ func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error rows, err := conn.Query(context.Background(), q, table) for rows.Next() { var rel Relation - if err := rows.Scan(&rel.Table, &rel.Column); err == nil { + if err := rows.Scan(&rel.Main, &rel.MainColumn, &rel.Related, &rel.RelatedColumn); err == nil { relations = append(relations, rel) } } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index bd42eb0..5d75a42 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -10,16 +10,16 @@ import ( func TestGetRelations(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string table string conn *pgx.Conn wantRelations []Relation }{ - {"With relation", "simple", conn, []Relation{{"relation", "simple_id"}}}, + {"With relation", "simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -29,3 +29,20 @@ func TestGetRelations(t *testing.T) { }) } } + +func TestRelation_Query(t *testing.T) { + tests := []struct { + name string + r Relation + want string + }{ + {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.Query(); got != tt.want { + t.Errorf("Relation.Query() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index f8dbdd4..7c1542a 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -33,21 +33,33 @@ func NewSync(source string, target string, fraction float64, verbose bool) (*Syn }, nil } -func (s *Sync) Sync() (err error) { - var tables []Table - if tables, err = GetTablesWithRows(s.source); err != nil { +// Close closes the connections to the source and destination databases +func (s *Sync) Close() { + s.source.Close(context.Background()) + s.destination.Close(context.Background()) +} + +// copyTableData copies the data from a table in the source database to the destination database +func copyTableData(table Table, source *pgx.Conn, destination *pgx.Conn) (err error) { + var data string + if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { return } - - var subset []Table - if subset = GetTargetSet(s.fraction, tables); err != nil { + if err = CopyStringToTable(table.Name, data, destination); err != nil { return } + return +} + +// ViableSubset returns a subset of tables that can be copied to the destination database +func ViableSubset(tables []Table) (subset []Table) { + // Filter out tables with no rows - subset = lo.Filter(subset, func(table Table, _ int) bool { return table.Rows > 0 }) + subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - // Generate a list of relations that should be excluded from the subset - relations := lo.FlatMap(subset, func(table Table, _ int) []string { return table.Relations }) + // Get all relations + relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) + relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.Related }) // Filter out tables that are relations of other tables // they will be copied later @@ -55,21 +67,47 @@ func (s *Sync) Sync() (err error) { return !lo.Contains(relations, table.Name) }) - for _, table := range subset { - var data string - if data, err = CopyTableToString(table.Name, table.Rows, s.source); err != nil { + return +} + +// CopyTables copies the data from a list of tables in the source database to the destination database +func (s *Sync) CopyTables(tables []Table) (err error) { + for _, table := range tables { + log.Info().Msgf("Copying table %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { return } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + for _, relation := range table.Relations { + // Backtrace the inserted ids from main table to related table + + log.Info().Msgf("Copying relation %s for table %s", relation, table.Name) + var data string + if data, err = CopyQueryToString(relation.Query(), s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } } } + return +} - //copy relations, TODO: make this work - for _, table := range subset { - for _, relation := range table.Relations { - log.Debug().Msgf("Copying relation %s for table %s", relation, table.Name) - } +// Sync copies a subset of tables from source to destination +func (s *Sync) Sync() (err error) { + var tables []Table + if tables, err = GetTablesWithRows(s.source); err != nil { + return + } + + var allTables []Table + if allTables = GetTargetSet(s.fraction, tables); err != nil { + return + } + + subset := ViableSubset(allTables) + if err = s.CopyTables(subset); err != nil { + return } return diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go new file mode 100644 index 0000000..e6fec2e --- /dev/null +++ b/subsetter/sync_test.go @@ -0,0 +1,49 @@ +package subsetter + +import ( + "reflect" + "testing" +) + +func TestViableSubset(t *testing.T) { + tests := []struct { + name string + tables []Table + wantSubset []Table + }{ + {"Simple", []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 10, []Relation{}}}}, + {"No rows", []Table{{"simple", 0, []Relation{}}}, []Table{}}, + { + "Complex, related tables must be excluded", + []Table{{"simple", 10, []Relation{}}, {"complex", 10, []Relation{{"simple", "id", "complex", "simple_id"}}}}, + []Table{{"simple", 10, []Relation{}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotSubset := ViableSubset(tt.tables); !reflect.DeepEqual(gotSubset, tt.wantSubset) { + t.Errorf("ViableSubset() = %v, want %v", gotSubset, tt.wantSubset) + } + }) + } +} + +func TestSync_CopyTables(t *testing.T) { + src := getTestConnection() + dst := getTestConnectionDst() + initSchema(src) + initSchema(dst) + defer clearSchema(src) + defer clearSchema(dst) + + s := &Sync{ + source: src, + destination: dst, + } + tables := []Table{{"simple", 10, []Relation{}}} + + if err := s.CopyTables(tables); err != nil { + t.Errorf("Sync.CopyTables() error = %v", err) + } + +}