diff --git a/.gitignore b/.gitignore index 6f0de07..14d5b56 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .pgsql/data bin dist/ +*.sql \ No newline at end of file diff --git a/Makefile b/Makefile index 8771f2c..7356abb 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ is-postgres-running: .PHONY: pgweb pgweb:is-postgres-running - @pgweb --url "postgres://test_source@localhost:5432/test_source?sslmode=disable" + @pgweb --url "postgres://test_target@localhost:5432/test_target?sslmode=disable" build: rm -rf dist diff --git a/README.md b/README.md index fb17185..fccd157 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@ [![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) [![vuln](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml) -`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. This means that your target database has to have schema populated in some other way. +`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA. + ### Database Fraction Synchronization `pg-subsetter` allows you to select and sync a specific subset of your database. Whether it's a fraction of a table or a particular dataset, you can have it replicated in another database without synchronizing the entire DB. @@ -17,6 +18,8 @@ Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data trans ### Stateless Operation `pg-subsetter` is built to be stateless, meaning it does not maintain any internal state between runs. This ensures that each synchronization process is independent, enhancing reliability and making it easier to manage and scale. +### Sync required rows +`pg-subsetter` can be isntructed to copy certain rows in specific tables, the command can be used multipel times to sync more data. ## Usage @@ -27,7 +30,7 @@ Usage of subsetter: -f float Fraction of rows to copy (default 0.05) -force value - Query to copy required tables (users: id = 1) + Query to copy required tables 'table: whois query', can be used multiple times -src string Source database DSN ``` @@ -35,13 +38,25 @@ Usage of subsetter: ### Example -Copy a fraction of the database and force certain rows to be also copied over. + +Prepare schema in target databse: + +```bash +pg_dump --schema-only -n public -f schemadump.sql "postgres://test_source@localhost:5432/test_source?sslmode=disable" +psql -f schemadump.sql "postgres://test_target@localhost:5432/test_target?sslmode=disable" +``` + +Copy a fraction of the database and force certain rows to be also copied over: ``` pg-subsetter \ -src "postgres://test_source@localhost:5432/test_source?sslmode=disable" \ -dst "postgres://test_target@localhost:5432/test_target?sslmode=disable" \ - -f 0.05 + -f 0.5 + -force "user: id=1" + -force "group: id=1" + -force "domains: domain_name ilike '%.si'" + ``` # Installing diff --git a/cli/main.go b/cli/main.go index cdd6c27..a9c8ab1 100644 --- a/cli/main.go +++ b/cli/main.go @@ -20,7 +20,7 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack - flag.Var(&forceSync, "force", "Query to copy required tables (users: id = 1)") + flag.Var(&forceSync, "force", "Query to copy required tables 'users: id = 1', can be used multiple times") flag.Parse() if *src == "" || *dst == "" { @@ -37,14 +37,14 @@ func main() { s, err := subsetter.NewSync(*src, *dst, *fraction, forceSync, *verbose) if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed to configure sync") + log.Fatal().Msg("Failed to configure sync") } defer s.Close() err = s.Sync() if err != nil { - log.Fatal().Stack().Err(err).Msg("Failed to sync") + log.Fatal().Msg("Failed to sync") } } diff --git a/flake.lock b/flake.lock index a95c7d6..1231073 100644 --- a/flake.lock +++ b/flake.lock @@ -20,16 +20,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1691674635, - "narHash": "sha256-dTWUqEf7lb7k67cFZXIG7Xe9ES6XEvKzzUT24A4hGa4=", + "lastModified": 1691709280, + "narHash": "sha256-zmfH2OlZEXwv572d0g8f6M5Ac6RiO8TxymOpY3uuqrM=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "1eef5102c9fcb3281fbf94de90e7d59c92664373", + "rev": "cf73a86c35a84de0e2f3ba494327cf6fb51c0dfd", "type": "github" }, "original": { "owner": "NixOS", - "ref": "release-23.05", + "ref": "nixpkgs-unstable", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index 6367fbc..76b7e50 100644 --- a/flake.nix +++ b/flake.nix @@ -3,7 +3,7 @@ allowed-users = [ "@wheel" "@staff" ]; # allow compiling on every device/machine }; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/release-23.05"; + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; flake-parts.url = "github:hercules-ci/flake-parts"; }; outputs = inputs@{ self, nixpkgs, flake-parts, ... }: @@ -35,7 +35,7 @@ go goreleaser golangci-lint - postgresql + postgresql_15 process-compose nixpkgs-fmt pgweb diff --git a/go.mod b/go.mod index 105b568..847466a 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( ) require ( + github.com/davecgh/go-spew v1.1.1 github.com/jackc/pgx/v5 v5.4.3 github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/go.sum b/go.sum index 95de6cf..131fcde 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,7 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= diff --git a/subsetter/query.go b/subsetter/query.go index 546d45d..d340480 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "fmt" + "strings" "github.com/jackc/pgx/v5/pgxpool" + "github.com/samber/lo" ) type Table struct { @@ -14,6 +16,16 @@ type Table struct { Relations []Relation } +func (t *Table) RelationNames() (names string) { + rel := lo.Map(t.Relations, func(r Relation, _ int) string { + return r.ForeignTable + ">" + r.ForeignColumn + }) + if len(rel) > 0 { + return strings.Join(rel, ", ") + } + return "none" +} + func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { q := `SELECT relname, @@ -53,10 +65,12 @@ func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err err if err != nil { return } + defer c.Release() if _, err = c.Conn().PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() + return } @@ -73,6 +87,8 @@ func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error if err != nil { return } + defer c.Release() + if _, err = c.Conn().PgConn().CopyFrom(context.Background(), &buff, q); err != nil { return } diff --git a/subsetter/relations.go b/subsetter/relations.go index 3d027d7..2fda3f2 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -3,8 +3,10 @@ package subsetter import ( "context" "fmt" + "regexp" "github.com/jackc/pgx/v5/pgxpool" + "github.com/samber/lo" ) type Relation struct { @@ -18,40 +20,57 @@ func (r *Relation) Query() string { return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.ForeignTable, r.ForeignColumn, r.PrimaryColumn, r.PrimaryTable) } +type RelationInfo struct { + TableName string + ForeignTable string + SQL string +} + +func (r *RelationInfo) toRelation() Relation { + var rel Relation + re := regexp.MustCompile(`FOREIGN KEY \((\w+)\) REFERENCES (\w+)\((\w+)\).*`) + matches := re.FindStringSubmatch(r.SQL) + if len(matches) == 4 { + rel.PrimaryColumn = matches[1] + rel.ForeignTable = matches[2] + rel.ForeignColumn = matches[3] + } + rel.PrimaryTable = r.TableName + return rel +} + // GetRelations returns a list of tables that have a foreign key for particular table. func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { q := `SELECT - kcu.table_name AS foreign_table_name, - kcu.column_name AS foreign_column_name, - ccu.table_name, - ccu.column_name + conrelid::regclass AS table_name, + confrelid::regclass AS refrerenced_table, + pg_get_constraintdef(c.oid, TRUE) AS sql FROM - information_schema.table_constraints tc - JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.referential_constraints rc ON tc.constraint_name = rc.constraint_name - AND tc.table_schema = rc.constraint_schema - JOIN information_schema.constraint_column_usage ccu ON rc.unique_constraint_name = ccu.constraint_name + pg_constraint c + JOIN pg_namespace n ON n.oid = c.connamespace WHERE - tc.constraint_type = 'FOREIGN KEY' - AND ccu.table_name = $1 - AND tc.table_schema = 'public';` + c.contype = 'f' + AND n.nspname = 'public';` - rows, err := conn.Query(context.Background(), q, table) + rows, err := conn.Query(context.Background(), q) if err != nil { return } defer rows.Close() for rows.Next() { - var rel Relation - err = rows.Scan(&rel.ForeignTable, &rel.ForeignColumn, &rel.PrimaryTable, &rel.PrimaryColumn) + var rel RelationInfo + + err = rows.Scan(&rel.TableName, &rel.ForeignTable, &rel.SQL) if err != nil { return } - relations = append(relations, rel) + relations = append(relations, rel.toRelation()) } + relations = lo.Filter(relations, func(rel Relation, _ int) bool { + return rel.PrimaryTable == table + }) return } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index b1471c4..2d717e8 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + "github.com/davecgh/go-spew/spew" "github.com/jackc/pgx/v5/pgxpool" ) @@ -17,7 +18,7 @@ func TestGetRelations(t *testing.T) { conn *pgxpool.Pool wantRelations []Relation }{ - {"With relation", "simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}}, + {"With relation", "relation", conn, []Relation{{"relation", "simple_id", "simple", "id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -44,3 +45,35 @@ func TestRelation_Query(t *testing.T) { }) } } + +func TestRelationInfo_toRelation(t *testing.T) { + + tests := []struct { + name string + fields RelationInfo + want Relation + }{ + { + "Simple", + RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id)"}, + Relation{"relation", "simple_id", "simple", "id"}, + }, + { + "Simple with cascade", + RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id) ON DELETE CASCADE"}, + Relation{"relation", "simple_id", "simple", "id"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &RelationInfo{ + TableName: tt.fields.TableName, + ForeignTable: tt.fields.ForeignTable, + SQL: tt.fields.SQL, + } + if got := r.toRelation(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("RelationInfo.toRelation() = %v, want %v", spew.Sdump(got), spew.Sdump(tt.want)) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index 1eb95d8..342e5e3 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -2,6 +2,7 @@ package subsetter import ( "context" + "sort" "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog/log" @@ -32,7 +33,7 @@ func NewSync(source string, target string, fraction float64, force []Force, verb return nil, err } - dst, err := pgxpool.New(context.Background(), source) + dst, err := pgxpool.New(context.Background(), target) if err != nil { return nil, err } @@ -60,9 +61,11 @@ func (s *Sync) Close() { func copyTableData(table Table, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { var data string if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { + log.Error().Err(err).Msgf("Error copying table %s", table.Name) return } if err = CopyStringToTable(table.Name, data, destination); err != nil { + log.Error().Err(err).Msgf("Error pasting table %s", table.Name) return } return @@ -73,17 +76,9 @@ func ViableSubset(tables []Table) (subset []Table) { // Filter out tables with no rows subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - - // Get all relations - relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) - relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.ForeignTable }) - - // Filter out tables that are relations of other tables - // they will be copied later - subset = lo.Filter(subset, func(table Table, _ int) bool { - return !lo.Contains(relations, table.Name) + sort.Slice(subset, func(i, j int) bool { + return len(subset[i].Relations) < len(subset[j].Relations) }) - return } @@ -140,8 +135,11 @@ func (s *Sync) Sync() (err error) { if s.verbose { for _, t := range subset { - log.Info().Msgf("Copying table %s with %d rows", t.Name, t.Rows) - log.Info().Msgf("Relations: %v", t.Relations) + log.Info(). + Str("table", t.Name). + Int("rows", t.Rows). + Str("related", t.RelationNames()). + Msg("Prepared for sync") } } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go index 67e85a3..1695a62 100644 --- a/subsetter/sync_test.go +++ b/subsetter/sync_test.go @@ -1,17 +1,10 @@ package subsetter import ( - "os" "reflect" "testing" ) -func skipCI(t *testing.T) { - if os.Getenv("CI") != "" { - t.Skip("Skipping testing in CI environment") - } -} - func TestViableSubset(t *testing.T) { tests := []struct { name string @@ -43,7 +36,6 @@ func TestViableSubset(t *testing.T) { } func TestSync_CopyTables(t *testing.T) { - skipCI(t) src := getTestConnection() dst := getTestConnectionDst() initSchema(src)