From 577978093f29d2476b5c8e0a0c9e402881b0377f Mon Sep 17 00:00:00 2001 From: Janez Troha Date: Fri, 11 Aug 2023 21:31:10 +0200 Subject: [PATCH] feat: Refactor CLI code to handle force sync option - Modified `String` method of `arrayForce` to join strings with a comma - Added import statements for packages "os" and "github.com/rs/zerolog" - Initialized logger with ConsoleWriter for standard error output - Added command line flag for "force" option - Logged message with tables to force sync if length of "forceSync" array is greater than 0 --- README.md | 21 ++++-- cli/force.go | 24 +++++++ cli/main.go | 27 +++++++- flake.nix | 1 - go.mod | 3 + go.sum | 5 ++ subsetter/db_test.go | 30 +++++++-- subsetter/info_test.go | 11 ++-- subsetter/query.go | 36 ++++++++--- subsetter/query_test.go | 40 ++++++------ subsetter/relations.go | 19 ++++-- subsetter/relations_test.go | 30 +++++++-- subsetter/sync.go | 124 +++++++++++++++++++++++++++++------- subsetter/sync_test.go | 49 ++++++++++++++ 14 files changed, 335 insertions(+), 85 deletions(-) create mode 100644 cli/force.go create mode 100644 subsetter/sync_test.go diff --git a/README.md b/README.md index 8751565..e0d09e0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) -`pg-subsetter` is a powerful and efficient tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA, this means that your target database has to have schema populated in some other way. +`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. This means that your target database has to have schema populated in some other way. ### Database Fraction Synchronization `pg-subsetter` allows you to select and sync a specific subset of your database. Whether it's a fraction of a table or a particular dataset, you can have it replicated in another database without synchronizing the entire DB. @@ -12,7 +12,7 @@ Foreign keys play a vital role in maintaining the relationships between tables. `pg-subsetter` ensures that all foreign keys are handled correctly during the synchronization process, maintaining the integrity and relationships of the data. ### Efficient COPY Method -Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. +Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. ### Stateless Operation `pg-subsetter` is built to be stateless, meaning it does not maintain any internal state between runs. This ensures that each synchronization process is independent, enhancing reliability and making it easier to manage and scale. @@ -23,17 +23,26 @@ Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfe ``` Usage of subsetter: -dst string - Destination DSN + Destination database DSN -f float Fraction of rows to copy (default 0.05) + -force value + Query to copy required tables (users: id = 1) -src string - Source DSN + Source database DSN ``` -Example: +### Example -```pg-subsetter -src postgresql://:@/bigdb -dst postgresql://:@/littledb -f 0.05``` +Copy a fraction of the database and force certain rows to be also copied over. + +``` +pg-subsetter \ + -src postgres://Grafana:p66cbcc3dc57f084a5a1587ee1fefda6bdd48f9a33e19e4862950d8ad80cc4f10@ec2-54-195-188-5.eu-west-1.compute.amazonaws.com:5432/dch44sm7ppvtkm \ + -dst postgresql://test_target@localhost/test_target \ + -f 0.05 +``` # Installing diff --git a/cli/force.go b/cli/force.go new file mode 100644 index 0000000..56a4635 --- /dev/null +++ b/cli/force.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "strings" + + "niteo.co/subsetter/subsetter" +) + +type arrayForce []subsetter.Force + +func (i *arrayForce) String() string { + return fmt.Sprintf("%v", *i) +} + +func (i *arrayForce) Set(value string) error { + q := strings.SplitAfter(strings.TrimSpace(value), ":") + + *i = append(*i, subsetter.Force{ + Table: q[0], + Where: q[1], + }) + return nil +} diff --git a/cli/main.go b/cli/main.go index dd07c59..cdd6c27 100644 --- a/cli/main.go +++ b/cli/main.go @@ -2,17 +2,26 @@ package main import ( "flag" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/rs/zerolog/pkgerrors" + "niteo.co/subsetter/subsetter" ) var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy") +var verbose = flag.Bool("verbose", true, "Show more information during sync") +var forceSync arrayForce func main() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + + flag.Var(&forceSync, "force", "Query to copy required tables (users: id = 1)") flag.Parse() - log.Info().Msg("Starting") if *src == "" || *dst == "" { log.Fatal().Msg("Source and destination DSNs are required") @@ -22,4 +31,20 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } + if len(forceSync) > 0 { + log.Info().Str("forced", forceSync.String()).Msg("Forcing sync for tables") + } + + s, err := subsetter.NewSync(*src, *dst, *fraction, forceSync, *verbose) + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to configure sync") + } + + defer s.Close() + + err = s.Sync() + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to sync") + } + } diff --git a/flake.nix b/flake.nix index 29396f0..6367fbc 100644 --- a/flake.nix +++ b/flake.nix @@ -37,7 +37,6 @@ golangci-lint postgresql process-compose - shellcheck nixpkgs-fmt pgweb ]; diff --git a/go.mod b/go.mod index 8c953fa..105b568 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,10 @@ require github.com/rs/zerolog v1.30.0 require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pkg/errors v0.9.1 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect + golang.org/x/sync v0.1.0 // indirect ) require ( diff --git a/go.sum b/go.sum index d87d999..95de6cf 100644 --- a/go.sum +++ b/go.sum @@ -8,10 +8,13 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -28,6 +31,8 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= diff --git a/subsetter/db_test.go b/subsetter/db_test.go index 5c6ecea..700b4d7 100644 --- a/subsetter/db_test.go +++ b/subsetter/db_test.go @@ -4,24 +4,38 @@ import ( "context" "fmt" "os" + "testing" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -func getTestConnection() *pgx.Conn { +func getTestConnection() *pgxpool.Pool { DATABASE_URL := os.Getenv("DATABASE_URL") if DATABASE_URL == "" { DATABASE_URL = "postgres://test_source@localhost:5432/test_source?sslmode=disable" } - conn, err := pgx.Connect(context.Background(), DATABASE_URL) + conn, err := pgxpool.New(context.Background(), DATABASE_URL) if err != nil { panic(err) } return conn } -func populateTests(conn *pgx.Conn) { +func getTestConnectionDst() *pgxpool.Pool { + DATABASE_URL := os.Getenv("DATABASE_URL_DST") + if DATABASE_URL == "" { + DATABASE_URL = "postgres://test_target@localhost:5432/test_target?sslmode=disable" + } + + conn, err := pgxpool.New(context.Background(), DATABASE_URL) + if err != nil { + panic(err) + } + return conn +} + +func initSchema(conn *pgxpool.Pool) { _, err := conn.Exec(context.Background(), ` CREATE TABLE simple ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -41,7 +55,7 @@ func populateTests(conn *pgx.Conn) { } } -func populateTestsWithData(conn *pgx.Conn, table string, size int) { +func populateTestsWithData(conn *pgxpool.Pool, table string, size int) { for i := 0; i < size; i++ { query := fmt.Sprintf("INSERT INTO %s (text) VALUES ('test%d') RETURNING id", table, i) var row string @@ -58,7 +72,7 @@ func populateTestsWithData(conn *pgx.Conn, table string, size int) { } } -func clearPopulateTests(conn *pgx.Conn) { +func clearSchema(conn *pgxpool.Pool) { _, err := conn.Exec(context.Background(), ` ALTER TABLE relation DROP CONSTRAINT relation_simple_fk; DROP TABLE simple; @@ -68,3 +82,7 @@ func clearPopulateTests(conn *pgx.Conn) { panic(err) } } + +func TestMOck(t *testing.T) { + populateTestsWithData(getTestConnection(), "simple", 100) +} diff --git a/subsetter/info_test.go b/subsetter/info_test.go index 28d843d..5764207 100644 --- a/subsetter/info_test.go +++ b/subsetter/info_test.go @@ -1,15 +1,14 @@ package subsetter import ( - "context" "testing" ) func TestGetTargetSet(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) tests := []struct { name string @@ -17,8 +16,8 @@ func TestGetTargetSet(t *testing.T) { tables []Table want []Table }{ - {"simple", 0.5, []Table{{"simple", 1000, []string{}}}, []Table{{"simple", 31, []string{}}}}, - {"simple", 0.5, []Table{{"simple", 10, []string{}}}, []Table{{"simple", 3, []string{}}}}, + {"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}}, + {"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/query.go b/subsetter/query.go index 4bd5856..caa44b9 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -5,16 +5,16 @@ import ( "context" "fmt" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) type Table struct { Name string Rows int - Relations []string + Relations []Relation } -func GetTables(conn *pgx.Conn) (tables []string, err error) { +func GetTables(conn *pgxpool.Pool) (tables []string, err error) { q := `SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';` rows, err := conn.Query(context.Background(), q) for rows.Next() { @@ -27,7 +27,7 @@ func GetTables(conn *pgx.Conn) (tables []string, err error) { return } -func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { +func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { q := `SELECT relname, reltuples::int FROM pg_class,information_schema.tables WHERE table_schema = 'public' AND relname = table_name;` rows, err := conn.Query(context.Background(), q) for rows.Next() { @@ -37,6 +37,11 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { // fix for tables with no rows if table.Rows == -1 { table.Rows = 0 + } else { + table.Relations, err = GetRelations(table.Name, conn) + if err != nil { + return nil, err + } } tables = append(tables, table) } @@ -47,21 +52,34 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { return } -func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { - q := fmt.Sprintf(`copy (SELECT * FROM %s order by random() limit %d) to stdout`, table, limit) +func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer - if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + c, err := conn.Acquire(context.Background()) + if err != nil { + return + } + if _, err = c.Conn().PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() return } -func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) { +func CopyTableToString(table string, limit int, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit) + return CopyQueryToString(q, conn) +} + +func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error) { q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer buff.WriteString(data) - if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + c, err := conn.Acquire(context.Background()) + if err != nil { + return + } + if _, err = c.Conn().PgConn().CopyFrom(context.Background(), &buff, q); err != nil { return } diff --git a/subsetter/query_test.go b/subsetter/query_test.go index 0bb4f58..be3a9df 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -6,17 +6,17 @@ import ( "strings" "testing" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) func TestGetTables(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) tests := []struct { name string - conn *pgx.Conn + conn *pgxpool.Pool wantTables []string }{ {"With tables", conn, []string{"simple", "relation"}}, @@ -32,16 +32,16 @@ func TestGetTables(t *testing.T) { func TestGetTablesWithRows(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) tests := []struct { name string - conn *pgx.Conn + conn *pgxpool.Pool wantTables []Table wantErr bool }{ - {"With tables", conn, []Table{{"simple", 0, []string{}}, {"relation", 0, []string{}}}, false}, + {"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -62,15 +62,15 @@ func TestGetTablesWithRows(t *testing.T) { func TestCopyTableToString(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) populateTestsWithData(conn, "simple", 10) tests := []struct { name string table string - conn *pgx.Conn + conn *pgxpool.Pool wantResult bool wantErr bool }{ @@ -83,7 +83,7 @@ func TestCopyTableToString(t *testing.T) { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return } - if strings.Contains(gotResult, "test") == tt.wantResult { + if strings.Contains(gotResult, "test") != tt.wantResult { t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult) } }) @@ -92,15 +92,15 @@ func TestCopyTableToString(t *testing.T) { func TestCopyStringToTable(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) tests := []struct { name string table string data string - conn *pgx.Conn + conn *pgxpool.Pool wantResult int wantErr bool }{ @@ -123,7 +123,7 @@ func TestCopyStringToTable(t *testing.T) { } } -func insertedRows(s string, conn *pgx.Conn) int { +func insertedRows(s string, conn *pgxpool.Pool) int { q := "SELECT count(*) FROM " + s var count int err := conn.QueryRow(context.Background(), q).Scan(&count) diff --git a/subsetter/relations.go b/subsetter/relations.go index ea96639..b59b585 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -2,19 +2,28 @@ package subsetter import ( "context" + "fmt" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) type Relation struct { - Table string - Column string + Main string + MainColumn string + Related string + RelatedColumn string +} + +func (r *Relation) Query() string { + return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.Related, r.RelatedColumn, r.MainColumn, r.Main) } // GetRelations returns a list of tables that have a foreign key for particular table. -func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error) { +func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { q := `SELECT + ccu.table_name, + ccu.column_name, kcu.table_name, kcu.column_name FROM @@ -28,7 +37,7 @@ func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error rows, err := conn.Query(context.Background(), q, table) for rows.Next() { var rel Relation - if err := rows.Scan(&rel.Table, &rel.Column); err == nil { + if err := rows.Scan(&rel.Main, &rel.MainColumn, &rel.Related, &rel.RelatedColumn); err == nil { relations = append(relations, rel) } } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index bd42eb0..0181b46 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -1,25 +1,24 @@ package subsetter import ( - "context" "reflect" "testing" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) func TestGetRelations(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer conn.Close() + defer clearSchema(conn) tests := []struct { name string table string - conn *pgx.Conn + conn *pgxpool.Pool wantRelations []Relation }{ - {"With relation", "simple", conn, []Relation{{"relation", "simple_id"}}}, + {"With relation", "simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -29,3 +28,20 @@ func TestGetRelations(t *testing.T) { }) } } + +func TestRelation_Query(t *testing.T) { + tests := []struct { + name string + r Relation + want string + }{ + {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.Query(); got != tt.want { + t.Errorf("Relation.Query() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index f8dbdd4..8d47481 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -3,24 +3,40 @@ package subsetter import ( "context" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog/log" "github.com/samber/lo" ) type Sync struct { - source *pgx.Conn - destination *pgx.Conn + source *pgxpool.Pool + destination *pgxpool.Pool fraction float64 verbose bool + force []Force } -func NewSync(source string, target string, fraction float64, verbose bool) (*Sync, error) { - src, err := pgx.Connect(context.Background(), source) +type Force struct { + Table string + Where string +} + +func NewSync(source string, target string, fraction float64, force []Force, verbose bool) (*Sync, error) { + src, err := pgxpool.New(context.Background(), source) + if err != nil { + return nil, err + } + + src.Ping(context.Background()) if err != nil { return nil, err } - dst, err := pgx.Connect(context.Background(), source) + + dst, err := pgxpool.New(context.Background(), source) + if err != nil { + return nil, err + } + dst.Ping(context.Background()) if err != nil { return nil, err } @@ -30,24 +46,37 @@ func NewSync(source string, target string, fraction float64, verbose bool) (*Syn destination: dst, fraction: fraction, verbose: verbose, + force: force, }, nil } -func (s *Sync) Sync() (err error) { - var tables []Table - if tables, err = GetTablesWithRows(s.source); err != nil { +// Close closes the connections to the source and destination databases +func (s *Sync) Close() { + s.source.Close() + s.destination.Close() +} + +// copyTableData copies the data from a table in the source database to the destination database +func copyTableData(table Table, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { + var data string + if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { return } - - var subset []Table - if subset = GetTargetSet(s.fraction, tables); err != nil { + if err = CopyStringToTable(table.Name, data, destination); err != nil { return } + return +} + +// ViableSubset returns a subset of tables that can be copied to the destination database +func ViableSubset(tables []Table) (subset []Table) { + // Filter out tables with no rows - subset = lo.Filter(subset, func(table Table, _ int) bool { return table.Rows > 0 }) + subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - // Generate a list of relations that should be excluded from the subset - relations := lo.FlatMap(subset, func(table Table, _ int) []string { return table.Relations }) + // Get all relations + relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) + relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.Related }) // Filter out tables that are relations of other tables // they will be copied later @@ -55,22 +84,69 @@ func (s *Sync) Sync() (err error) { return !lo.Contains(relations, table.Name) }) - for _, table := range subset { - var data string - if data, err = CopyTableToString(table.Name, table.Rows, s.source); err != nil { + return +} + +// CopyTables copies the data from a list of tables in the source database to the destination database +func (s *Sync) CopyTables(tables []Table) (err error) { + for _, table := range tables { + log.Info().Msgf("Copying table %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { return } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + + for _, force := range s.force { + if force.Table == table.Name { + log.Info().Msgf("Selecting forced rows for table %s", table.Name) + var data string + if data, err = CopyQueryToString(force.Where, s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } + } } - } - //copy relations, TODO: make this work - for _, table := range subset { for _, relation := range table.Relations { - log.Debug().Msgf("Copying relation %s for table %s", relation, table.Name) + // Backtrace the inserted ids from main table to related table + + log.Info().Msgf("Copying relation %s for table %s", relation, table.Name) + var data string + if data, err = CopyQueryToString(relation.Query(), s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } + } + } + return +} + +// Sync copies a subset of tables from source to destination +func (s *Sync) Sync() (err error) { + var tables []Table + if tables, err = GetTablesWithRows(s.source); err != nil { + return + } + + var allTables []Table + if allTables = GetTargetSet(s.fraction, tables); err != nil { + return + } + + subset := ViableSubset(allTables) + + if s.verbose { + for _, T := range subset { + log.Info().Msgf("Copying table %s with %d rows", T.Name, T.Rows) } } + if err = s.CopyTables(subset); err != nil { + return + } + return } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go new file mode 100644 index 0000000..e6fec2e --- /dev/null +++ b/subsetter/sync_test.go @@ -0,0 +1,49 @@ +package subsetter + +import ( + "reflect" + "testing" +) + +func TestViableSubset(t *testing.T) { + tests := []struct { + name string + tables []Table + wantSubset []Table + }{ + {"Simple", []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 10, []Relation{}}}}, + {"No rows", []Table{{"simple", 0, []Relation{}}}, []Table{}}, + { + "Complex, related tables must be excluded", + []Table{{"simple", 10, []Relation{}}, {"complex", 10, []Relation{{"simple", "id", "complex", "simple_id"}}}}, + []Table{{"simple", 10, []Relation{}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotSubset := ViableSubset(tt.tables); !reflect.DeepEqual(gotSubset, tt.wantSubset) { + t.Errorf("ViableSubset() = %v, want %v", gotSubset, tt.wantSubset) + } + }) + } +} + +func TestSync_CopyTables(t *testing.T) { + src := getTestConnection() + dst := getTestConnectionDst() + initSchema(src) + initSchema(dst) + defer clearSchema(src) + defer clearSchema(dst) + + s := &Sync{ + source: src, + destination: dst, + } + tables := []Table{{"simple", 10, []Relation{}}} + + if err := s.CopyTables(tables); err != nil { + t.Errorf("Sync.CopyTables() error = %v", err) + } + +}