diff --git a/README.md b/README.md index 8751565..e0d09e0 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) -`pg-subsetter` is a powerful and efficient tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA, this means that your target database has to have schema populated in some other way. +`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. This means that your target database has to have schema populated in some other way. ### Database Fraction Synchronization `pg-subsetter` allows you to select and sync a specific subset of your database. Whether it's a fraction of a table or a particular dataset, you can have it replicated in another database without synchronizing the entire DB. @@ -12,7 +12,7 @@ Foreign keys play a vital role in maintaining the relationships between tables. `pg-subsetter` ensures that all foreign keys are handled correctly during the synchronization process, maintaining the integrity and relationships of the data. ### Efficient COPY Method -Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. +Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. 
### Stateless Operation `pg-subsetter` is built to be stateless, meaning it does not maintain any internal state between runs. This ensures that each synchronization process is independent, enhancing reliability and making it easier to manage and scale. @@ -23,17 +23,26 @@ Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfe ``` Usage of subsetter: -dst string - Destination DSN + Destination database DSN -f float Fraction of rows to copy (default 0.05) + -force value + Query to copy required tables (users: id = 1) -src string - Source DSN + Source database DSN ``` -Example: +### Example -```pg-subsetter -src postgresql://:@/bigdb -dst postgresql://:@/littledb -f 0.05``` +Copy a fraction of the database and force certain rows to be also copied over. + +``` +pg-subsetter \ + -src postgres://Grafana:p66cbcc3dc57f084a5a1587ee1fefda6bdd48f9a33e19e4862950d8ad80cc4f10@ec2-54-195-188-5.eu-west-1.compute.amazonaws.com:5432/dch44sm7ppvtkm \ + -dst postgresql://test_target@localhost/test_target \ + -f 0.05 +``` # Installing diff --git a/cli/force.go b/cli/force.go new file mode 100644 index 0000000..56a4635 --- /dev/null +++ b/cli/force.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "strings" + + "niteo.co/subsetter/subsetter" +) + +type arrayForce []subsetter.Force + +func (i *arrayForce) String() string { + return fmt.Sprintf("%v", *i) +} + +func (i *arrayForce) Set(value string) error { + q := strings.SplitAfter(strings.TrimSpace(value), ":") + + *i = append(*i, subsetter.Force{ + Table: q[0], + Where: q[1], + }) + return nil +} diff --git a/cli/main.go b/cli/main.go index dd07c59..71687c3 100644 --- a/cli/main.go +++ b/cli/main.go @@ -2,17 +2,26 @@ package main import ( "flag" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/rs/zerolog/pkgerrors" + "niteo.co/subsetter/subsetter" ) var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = 
flag.Float64("f", 0.05, "Fraction of rows to copy") +var verbose = flag.Bool("verbose", false, "Show more information during sync") +var forceSync arrayForce func main() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + + flag.Var(&forceSync, "force", "Query to copy required tables (users: id = 1)") flag.Parse() - log.Info().Msg("Starting") if *src == "" || *dst == "" { log.Fatal().Msg("Source and destination DSNs are required") @@ -22,4 +31,20 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } + if len(forceSync) > 0 { + log.Info().Str("forced", forceSync.String()).Msg("Forcing sync for tables") + } + + s, err := subsetter.NewSync(*src, *dst, *fraction, forceSync, *verbose) + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to configure sync") + } + + defer s.Close() + + err = s.Sync() + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to sync") + } + } diff --git a/flake.nix b/flake.nix index 29396f0..6367fbc 100644 --- a/flake.nix +++ b/flake.nix @@ -37,7 +37,6 @@ golangci-lint postgresql process-compose - shellcheck nixpkgs-fmt pgweb ]; diff --git a/go.mod b/go.mod index 8c953fa..608395d 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require github.com/rs/zerolog v1.30.0 require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/pkg/errors v0.9.1 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect ) diff --git a/go.sum b/go.sum index d87d999..2433299 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZb github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod 
h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/subsetter/db_test.go b/subsetter/db_test.go index 5c6ecea..4382f3f 100644 --- a/subsetter/db_test.go +++ b/subsetter/db_test.go @@ -21,7 +21,20 @@ func getTestConnection() *pgx.Conn { return conn } -func populateTests(conn *pgx.Conn) { +func getTestConnectionDst() *pgx.Conn { + DATABASE_URL := os.Getenv("DATABASE_URL_DST") + if DATABASE_URL == "" { + DATABASE_URL = "postgres://test_target@localhost:5432/test_target?sslmode=disable" + } + + conn, err := pgx.Connect(context.Background(), DATABASE_URL) + if err != nil { + panic(err) + } + return conn +} + +func initSchema(conn *pgx.Conn) { _, err := conn.Exec(context.Background(), ` CREATE TABLE simple ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -58,7 +71,7 @@ func populateTestsWithData(conn *pgx.Conn, table string, size int) { } } -func clearPopulateTests(conn *pgx.Conn) { +func clearSchema(conn *pgx.Conn) { _, err := conn.Exec(context.Background(), ` ALTER TABLE relation DROP CONSTRAINT relation_simple_fk; DROP TABLE simple; diff --git a/subsetter/info_test.go b/subsetter/info_test.go index 28d843d..a9fb68e 100644 --- a/subsetter/info_test.go +++ b/subsetter/info_test.go @@ -7,9 +7,9 @@ import ( func TestGetTargetSet(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string @@ -17,8 +17,8 @@ func TestGetTargetSet(t *testing.T) { tables []Table want []Table }{ - {"simple", 0.5, []Table{{"simple", 1000, []string{}}}, []Table{{"simple", 31, []string{}}}}, - 
{"simple", 0.5, []Table{{"simple", 10, []string{}}}, []Table{{"simple", 3, []string{}}}}, + {"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}}, + {"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/query.go b/subsetter/query.go index 4bd5856..52cd623 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -11,7 +11,7 @@ import ( type Table struct { Name string Rows int - Relations []string + Relations []Relation } func GetTables(conn *pgx.Conn) (tables []string, err error) { @@ -37,6 +37,11 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { // fix for tables with no rows if table.Rows == -1 { table.Rows = 0 + } else { + table.Relations, err = GetRelations(table.Name, conn) + if err != nil { + return nil, err + } } tables = append(tables, table) } @@ -47,16 +52,21 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { return } -func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { - q := fmt.Sprintf(`copy (SELECT * FROM %s order by random() limit %d) to stdout`, table, limit) +func CopyQueryToString(query string, conn *pgx.Conn) (result string, err error) { + q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer - if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + if _, err = conn.PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() return } +func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit) + return CopyQueryToString(q, conn) +} + func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) { q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer diff --git 
a/subsetter/query_test.go b/subsetter/query_test.go index 0bb4f58..12c98c8 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -11,9 +11,9 @@ import ( func TestGetTables(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string conn *pgx.Conn @@ -32,16 +32,16 @@ func TestGetTables(t *testing.T) { func TestGetTablesWithRows(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name string conn *pgx.Conn wantTables []Table wantErr bool }{ - {"With tables", conn, []Table{{"simple", 0, []string{}}, {"relation", 0, []string{}}}, false}, + {"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -62,9 +62,9 @@ func TestGetTablesWithRows(t *testing.T) { func TestCopyTableToString(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) populateTestsWithData(conn, "simple", 10) tests := []struct { @@ -83,7 +83,7 @@ func TestCopyTableToString(t *testing.T) { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return } - if strings.Contains(gotResult, "test") == tt.wantResult { + if strings.Contains(gotResult, "test") != tt.wantResult { t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult) } }) @@ -92,9 +92,9 @@ func TestCopyTableToString(t *testing.T) { func TestCopyStringToTable(t *testing.T) { conn := getTestConnection() - populateTests(conn) + initSchema(conn) defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + defer clearSchema(conn) tests := []struct { name 
// Relation describes a foreign-key link between two tables.
//
// Main and MainColumn name the referenced table and column; Related and
// RelatedColumn name the table and column that hold the foreign key.
type Relation struct {
	Main          string
	MainColumn    string
	Related       string
	RelatedColumn string
}

// Query returns a SELECT statement for all rows of the related table whose
// foreign-key column has a value that is present in the main table's column.
//
// Table and column names are inserted into the statement verbatim, without
// any quoting. The method only reads the receiver, so it uses a value
// receiver; it can be called on both Relation and *Relation values.
func (r Relation) Query() string {
	return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.Related, r.RelatedColumn, r.MainColumn, r.Main)
}
TestRelation_Query(t *testing.T) { + tests := []struct { + name string + r Relation + want string + }{ + {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.Query(); got != tt.want { + t.Errorf("Relation.Query() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index f8dbdd4..07d268a 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -13,9 +13,15 @@ type Sync struct { destination *pgx.Conn fraction float64 verbose bool + force []Force } -func NewSync(source string, target string, fraction float64, verbose bool) (*Sync, error) { +type Force struct { + Table string + Where string +} + +func NewSync(source string, target string, fraction float64, force []Force, verbose bool) (*Sync, error) { src, err := pgx.Connect(context.Background(), source) if err != nil { return nil, err @@ -30,24 +36,37 @@ func NewSync(source string, target string, fraction float64, verbose bool) (*Syn destination: dst, fraction: fraction, verbose: verbose, + force: force, }, nil } -func (s *Sync) Sync() (err error) { - var tables []Table - if tables, err = GetTablesWithRows(s.source); err != nil { +// Close closes the connections to the source and destination databases +func (s *Sync) Close() { + s.source.Close(context.Background()) + s.destination.Close(context.Background()) +} + +// copyTableData copies the data from a table in the source database to the destination database +func copyTableData(table Table, source *pgx.Conn, destination *pgx.Conn) (err error) { + var data string + if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { return } - - var subset []Table - if subset = GetTargetSet(s.fraction, tables); err != nil { + if err = CopyStringToTable(table.Name, data, destination); err != nil { return } + return +} + +// ViableSubset returns a 
subset of tables that can be copied to the destination database +func ViableSubset(tables []Table) (subset []Table) { + // Filter out tables with no rows - subset = lo.Filter(subset, func(table Table, _ int) bool { return table.Rows > 0 }) + subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - // Generate a list of relations that should be excluded from the subset - relations := lo.FlatMap(subset, func(table Table, _ int) []string { return table.Relations }) + // Get all relations + relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) + relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.Related }) // Filter out tables that are relations of other tables // they will be copied later @@ -55,22 +74,68 @@ func (s *Sync) Sync() (err error) { return !lo.Contains(relations, table.Name) }) - for _, table := range subset { - var data string - if data, err = CopyTableToString(table.Name, table.Rows, s.source); err != nil { + return +} + +// CopyTables copies the data from a list of tables in the source database to the destination database +func (s *Sync) CopyTables(tables []Table) (err error) { + for _, table := range tables { + log.Info().Msgf("Copying table %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { return } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + + for _, force := range s.force { + if force.Table == table.Name { + log.Info().Msgf("Selecting forced rows for table %s", table.Name) + var data string + if data, err = CopyQueryToString(force.Where, s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } + } } - } - //copy relations, TODO: make this work - for _, table := range subset { for _, relation := range table.Relations { - log.Debug().Msgf("Copying relation %s for table %s", relation, table.Name) + 
// Backtrace the inserted ids from main table to related table + + log.Info().Msgf("Copying relation %s for table %s", relation, table.Name) + var data string + if data, err = CopyQueryToString(relation.Query(), s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } + } + } + return +} + +// Sync copies a subset of tables from source to destination +func (s *Sync) Sync() (err error) { + var tables []Table + if tables, err = GetTablesWithRows(s.source); err != nil { + return + } + + var allTables []Table + if allTables = GetTargetSet(s.fraction, tables); err != nil { + return + } + + subset := ViableSubset(allTables) + if s.verbose { + for _, T := range subset { + log.Info().Msgf("Copying table %s with %d rows", T.Name, T.Rows) } } + if err = s.CopyTables(subset); err != nil { + return + } + return } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go new file mode 100644 index 0000000..e6fec2e --- /dev/null +++ b/subsetter/sync_test.go @@ -0,0 +1,49 @@ +package subsetter + +import ( + "reflect" + "testing" +) + +func TestViableSubset(t *testing.T) { + tests := []struct { + name string + tables []Table + wantSubset []Table + }{ + {"Simple", []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 10, []Relation{}}}}, + {"No rows", []Table{{"simple", 0, []Relation{}}}, []Table{}}, + { + "Complex, related tables must be excluded", + []Table{{"simple", 10, []Relation{}}, {"complex", 10, []Relation{{"simple", "id", "complex", "simple_id"}}}}, + []Table{{"simple", 10, []Relation{}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotSubset := ViableSubset(tt.tables); !reflect.DeepEqual(gotSubset, tt.wantSubset) { + t.Errorf("ViableSubset() = %v, want %v", gotSubset, tt.wantSubset) + } + }) + } +} + +func TestSync_CopyTables(t *testing.T) { + src := getTestConnection() + dst := getTestConnectionDst() + initSchema(src) + initSchema(dst) + defer 
clearSchema(src) + defer clearSchema(dst) + + s := &Sync{ + source: src, + destination: dst, + } + tables := []Table{{"simple", 10, []Relation{}}} + + if err := s.CopyTables(tables); err != nil { + t.Errorf("Sync.CopyTables() error = %v", err) + } + +}