diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3ff64fa..959e46d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,8 +10,6 @@ on: branches: ["main"] jobs: - vuln: - uses: cristalhq/.github/.github/workflows/vuln.yml@v0.4.0 build: runs-on: ubuntu-latest @@ -48,7 +46,7 @@ jobs: run: make build - name: Test - run: go test -v ./... + run: go test -timeout 30s -v ./... - name: Upload assets uses: actions/upload-artifact@v3 diff --git a/.github/workflows/vuln.yml b/.github/workflows/vuln.yml new file mode 100644 index 0000000..befeab5 --- /dev/null +++ b/.github/workflows/vuln.yml @@ -0,0 +1,11 @@ +name: vuln + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + vuln: + uses: cristalhq/.github/.github/workflows/vuln.yml@v0.4.0 diff --git a/README.md b/README.md index 8751565..fb17185 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # pg-subsetter -[![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) +[![lint](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/lint.yml) [![build](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/go.yml) [![vuln](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml/badge.svg)](https://github.com/teamniteo/pg-subsetter/actions/workflows/vuln.yml) -`pg-subsetter` is a powerful and efficient tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly, it does not copy the SCHEMA, this means that your target database has to have schema populated in some other way. 
+`pg-subsetter` is a tool designed to synchronize a fraction of a PostgreSQL database to another PostgreSQL database on the fly. It does not copy the schema, so the target database must already have its schema populated in some other way. ### Database Fraction Synchronization `pg-subsetter` allows you to select and sync a specific subset of your database. Whether it's a fraction of a table or a particular dataset, you can have it replicated in another database without synchronizing the entire DB. @@ -12,7 +12,7 @@ Foreign keys play a vital role in maintaining the relationships between tables. `pg-subsetter` ensures that all foreign keys are handled correctly during the synchronization process, maintaining the integrity and relationships of the data. ### Efficient COPY Method -Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. +Utilizing the native PostgreSQL COPY command, `pg-subsetter` performs data transfer with high efficiency. This method significantly speeds up the synchronization process, minimizing downtime and resource consumption. ### Stateless Operation `pg-subsetter` is built to be stateless, meaning it does not maintain any internal state between runs. This ensures that each synchronization process is independent, enhancing reliability and making it easier to manage and scale. 
@@ -23,17 +23,26 @@ Utilizing the native PostgreSQL COPY command, pg-subsetter performs data transfe ``` Usage of subsetter: -dst string - Destination DSN + Destination database DSN -f float Fraction of rows to copy (default 0.05) + -force value + Query to copy required tables (users: id = 1) -src string - Source DSN + Source database DSN ``` -Example: +### Example -```pg-subsetter -src postgresql://:@/bigdb -dst postgresql://:@/littledb -f 0.05``` +Copy a fraction of the database and force certain rows to be also copied over. + +``` +pg-subsetter \ + -src "postgres://test_source@localhost:5432/test_source?sslmode=disable" \ + -dst "postgres://test_target@localhost:5432/test_target?sslmode=disable" \ + -f 0.05 +``` # Installing diff --git a/cli/force.go b/cli/force.go new file mode 100644 index 0000000..56a4635 --- /dev/null +++ b/cli/force.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "strings" + + "niteo.co/subsetter/subsetter" +) + +type arrayForce []subsetter.Force + +func (i *arrayForce) String() string { + return fmt.Sprintf("%v", *i) +} + +// Set parses one -force value of the form "table: condition" and appends it to the list. +// It returns an error instead of panicking when the ":" separator or either part is missing. +func (i *arrayForce) Set(value string) error { + table, where, found := strings.Cut(value, ":") + table, where = strings.TrimSpace(table), strings.TrimSpace(where) + if !found || table == "" || where == "" { + return fmt.Errorf("invalid force value %q, expected \"table: condition\"", value) + } + + *i = append(*i, subsetter.Force{ + Table: table, + Where: where, + }) + return nil +} diff --git a/cli/main.go b/cli/main.go index dd07c59..cdd6c27 100644 --- a/cli/main.go +++ b/cli/main.go @@ -2,17 +2,26 @@ package main import ( "flag" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/rs/zerolog/pkgerrors" + "niteo.co/subsetter/subsetter" ) var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy") +var verbose = flag.Bool("verbose", true, "Show more information during sync") +var forceSync arrayForce func main() { + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + + flag.Var(&forceSync, 
"force", "Query to copy required tables (users: id = 1)") flag.Parse() - log.Info().Msg("Starting") if *src == "" || *dst == "" { log.Fatal().Msg("Source and destination DSNs are required") @@ -22,4 +31,20 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } + if len(forceSync) > 0 { + log.Info().Str("forced", forceSync.String()).Msg("Forcing sync for tables") + } + + s, err := subsetter.NewSync(*src, *dst, *fraction, forceSync, *verbose) + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to configure sync") + } + + defer s.Close() + + err = s.Sync() + if err != nil { + log.Fatal().Stack().Err(err).Msg("Failed to sync") + } + } diff --git a/flake.nix b/flake.nix index 29396f0..6367fbc 100644 --- a/flake.nix +++ b/flake.nix @@ -37,7 +37,6 @@ golangci-lint postgresql process-compose - shellcheck nixpkgs-fmt pgweb ]; diff --git a/go.mod b/go.mod index 8c953fa..105b568 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,10 @@ require github.com/rs/zerolog v1.30.0 require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pkg/errors v0.9.1 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect + golang.org/x/sync v0.1.0 // indirect ) require ( diff --git a/go.sum b/go.sum index d87d999..95de6cf 100644 --- a/go.sum +++ b/go.sum @@ -8,10 +8,13 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -28,6 +31,8 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= diff --git a/subsetter/db_test.go b/subsetter/db_test.go index 5c6ecea..7033e38 100644 --- a/subsetter/db_test.go +++ b/subsetter/db_test.go @@ -4,24 +4,65 @@ import ( "context" "fmt" "os" + "sync" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -func getTestConnection() *pgx.Conn { +var testConnSrc *pgxpool.Pool +var onceTestSrc sync.Once +var testConnTrg *pgxpool.Pool +var onceTestTrg sync.Once + +func getTestConnection() *pgxpool.Pool { DATABASE_URL := 
os.Getenv("DATABASE_URL") if DATABASE_URL == "" { DATABASE_URL = "postgres://test_source@localhost:5432/test_source?sslmode=disable" } - conn, err := pgx.Connect(context.Background(), DATABASE_URL) - if err != nil { - panic(err) + onceTestSrc.Do(func() { + c, err := pgxpool.New(context.Background(), DATABASE_URL) + if err != nil { + panic(err) + } + testConnSrc = c + }) + + return testConnSrc +} + +func getRealConnection(url string) *pgxpool.Pool { + + onceTestSrc.Do(func() { + c, err := pgxpool.New(context.Background(), url) + if err != nil { + panic(err) + } + testConnSrc = c + }) + + return testConnSrc +} + +func getTestConnectionDst() *pgxpool.Pool { + DATABASE_URL := os.Getenv("DATABASE_URL") + if DATABASE_URL == "" { + DATABASE_URL = "postgres://test_target@localhost:5432/test_target?sslmode=disable" } - return conn + + onceTestTrg.Do(func() { + c, err := pgxpool.New(context.Background(), DATABASE_URL) + if err != nil { + panic(err) + } + testConnTrg = c + }) + + return testConnTrg } -func populateTests(conn *pgx.Conn) { +func initSchema(conn *pgxpool.Pool) { + _, err := conn.Exec(context.Background(), ` CREATE TABLE simple ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), @@ -41,7 +82,7 @@ func populateTests(conn *pgx.Conn) { } } -func populateTestsWithData(conn *pgx.Conn, table string, size int) { +func populateTestsWithData(conn *pgxpool.Pool, table string, size int) { for i := 0; i < size; i++ { query := fmt.Sprintf("INSERT INTO %s (text) VALUES ('test%d') RETURNING id", table, i) var row string @@ -58,7 +99,7 @@ func populateTestsWithData(conn *pgx.Conn, table string, size int) { } } -func clearPopulateTests(conn *pgx.Conn) { +func clearSchema(conn *pgxpool.Pool) { _, err := conn.Exec(context.Background(), ` ALTER TABLE relation DROP CONSTRAINT relation_simple_fk; DROP TABLE simple; diff --git a/subsetter/info.go b/subsetter/info.go index 2025c99..6ee385f 100644 --- a/subsetter/info.go +++ b/subsetter/info.go @@ -8,8 +8,9 @@ func GetTargetSet(fraction 
float64, tables []Table) []Table { for _, table := range tables { subset = append(subset, Table{ - Name: table.Name, - Rows: int(math.Pow(10, math.Log10(float64(table.Rows))*fraction)), + Name: table.Name, + Rows: int(math.Pow(10, math.Log10(float64(table.Rows))*fraction)), + Relations: table.Relations, }) } diff --git a/subsetter/info_test.go b/subsetter/info_test.go index 28d843d..edc88de 100644 --- a/subsetter/info_test.go +++ b/subsetter/info_test.go @@ -1,15 +1,13 @@ package subsetter import ( - "context" "testing" ) func TestGetTargetSet(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer clearSchema(conn) tests := []struct { name string @@ -17,8 +15,8 @@ func TestGetTargetSet(t *testing.T) { tables []Table want []Table }{ - {"simple", 0.5, []Table{{"simple", 1000, []string{}}}, []Table{{"simple", 31, []string{}}}}, - {"simple", 0.5, []Table{{"simple", 10, []string{}}}, []Table{{"simple", 3, []string{}}}}, + {"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}}, + {"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/query.go b/subsetter/query.go index 4bd5856..546d45d 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -5,30 +5,25 @@ import ( "context" "fmt" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) type Table struct { Name string Rows int - Relations []string + Relations []Relation } -func GetTables(conn *pgx.Conn) (tables []string, err error) { - q := `SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';` - rows, err := conn.Query(context.Background(), q) - for rows.Next() { - var name string - if err := rows.Scan(&name); err == nil { - tables = append(tables, name) - } - } - rows.Close() - return -} - -func 
GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { - q := `SELECT relname, reltuples::int FROM pg_class,information_schema.tables WHERE table_schema = 'public' AND relname = table_name;` +func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { + q := `SELECT + relname, + reltuples::int + FROM + pg_class, + information_schema.tables + WHERE + table_schema = 'public' + AND relname = table_name;` rows, err := conn.Query(context.Background(), q) for rows.Next() { var table Table @@ -38,6 +33,10 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { if table.Rows == -1 { table.Rows = 0 } + table.Relations, err = GetRelations(table.Name, conn) + if err != nil { + return nil, err + } tables = append(tables, table) } @@ -47,21 +46,39 @@ func GetTablesWithRows(conn *pgx.Conn) (tables []Table, err error) { return } -func CopyTableToString(table string, limit int, conn *pgx.Conn) (result string, err error) { - q := fmt.Sprintf(`copy (SELECT * FROM %s order by random() limit %d) to stdout`, table, limit) +// CopyQueryToString runs `copy (QUERY) to stdout` on a connection acquired from the pool and returns the copied rows as text. +func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer - if _, err = conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + c, err := conn.Acquire(context.Background()) + if err != nil { + return + } + defer c.Release() // hand the connection back; an unreleased Acquire eventually exhausts the pool + if _, err = c.Conn().PgConn().CopyTo(context.Background(), &buff, q); err != nil { return } result = buff.String() return } -func CopyStringToTable(table string, data string, conn *pgx.Conn) (err error) { +// CopyTableToString copies up to limit randomly ordered rows of table and returns them as text. +func CopyTableToString(table string, limit int, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s order by random() limit %d`, table, limit) + return CopyQueryToString(q, conn) +} + +// CopyStringToTable loads data into table with `copy TABLE from stdin` on a connection acquired from the pool. +func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error) { q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer buff.WriteString(data) - if _, err = 
conn.PgConn().CopyFrom(context.Background(), &buff, q); err != nil { + c, err := conn.Acquire(context.Background()) + if err != nil { + return + } + defer c.Release() // hand the connection back to the pool once the copy finishes + if _, err = c.Conn().PgConn().CopyFrom(context.Background(), &buff, q); err != nil { return } diff --git a/subsetter/query_test.go b/subsetter/query_test.go index 0bb4f58..ff067f7 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -2,46 +2,23 @@ package subsetter import ( "context" - "reflect" "strings" "testing" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -func TestGetTables(t *testing.T) { - conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) - tests := []struct { - name string - conn *pgx.Conn - wantTables []string - }{ - {"With tables", conn, []string{"simple", "relation"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotTables, _ := GetTables(tt.conn); !reflect.DeepEqual(gotTables, tt.wantTables) { - t.Errorf("GetTables() = %v, want %v", gotTables, tt.wantTables) - } - }) - } -} - func TestGetTablesWithRows(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer clearSchema(conn) tests := []struct { name string - conn *pgx.Conn + conn *pgxpool.Pool wantTables []Table wantErr bool }{ - {"With tables", conn, []Table{{"simple", 0, []string{}}, {"relation", 0, []string{}}}, false}, + {"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -62,15 +39,14 @@ func TestGetTablesWithRows(t *testing.T) { func TestCopyTableToString(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer clearSchema(conn) populateTestsWithData(conn, "simple", 10) tests := 
[]struct { name string table string - conn *pgx.Conn + conn *pgxpool.Pool wantResult bool wantErr bool }{ @@ -83,7 +59,7 @@ func TestCopyTableToString(t *testing.T) { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return } - if strings.Contains(gotResult, "test") == tt.wantResult { + if strings.Contains(gotResult, "test") != tt.wantResult { t.Errorf("CopyTableToString() = %v, want %v", gotResult, tt.wantResult) } }) @@ -92,15 +68,14 @@ func TestCopyTableToString(t *testing.T) { func TestCopyStringToTable(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer clearSchema(conn) tests := []struct { name string table string data string - conn *pgx.Conn + conn *pgxpool.Pool wantResult int wantErr bool }{ @@ -123,7 +98,7 @@ func TestCopyStringToTable(t *testing.T) { } } -func insertedRows(s string, conn *pgx.Conn) int { +func insertedRows(s string, conn *pgxpool.Pool) int { q := "SELECT count(*) FROM " + s var count int err := conn.QueryRow(context.Background(), q).Scan(&count) diff --git a/subsetter/relations.go b/subsetter/relations.go index ea96639..3d027d7 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -2,36 +2,56 @@ package subsetter import ( "context" + "fmt" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) type Relation struct { - Table string - Column string + PrimaryTable string + PrimaryColumn string + ForeignTable string + ForeignColumn string +} + +func (r *Relation) Query() string { + return fmt.Sprintf(`SELECT * FROM %s WHERE %s IN (SELECT %s FROM %s)`, r.ForeignTable, r.ForeignColumn, r.PrimaryColumn, r.PrimaryTable) } // GetRelations returns a list of tables that have a foreign key for particular table. 
-func GetRelations(table string, conn *pgx.Conn) (relations []Relation, err error) { +func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { q := `SELECT - kcu.table_name, - kcu.column_name + kcu.table_name AS foreign_table_name, + kcu.column_name AS foreign_column_name, + ccu.table_name, + ccu.column_name FROM - information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name - JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name + information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.referential_constraints rc ON tc.constraint_name = rc.constraint_name + AND tc.table_schema = rc.constraint_schema + JOIN information_schema.constraint_column_usage ccu ON rc.unique_constraint_name = ccu.constraint_name WHERE tc.constraint_type = 'FOREIGN KEY' - AND ccu.table_name = $1;` + AND ccu.table_name = $1 + AND tc.table_schema = 'public';` rows, err := conn.Query(context.Background(), q, table) + if err != nil { + return + } + defer rows.Close() + for rows.Next() { var rel Relation - if err := rows.Scan(&rel.Table, &rel.Column); err == nil { - relations = append(relations, rel) + err = rows.Scan(&rel.ForeignTable, &rel.ForeignColumn, &rel.PrimaryTable, &rel.PrimaryColumn) + if err != nil { + return } + relations = append(relations, rel) } - rows.Close() + return } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index bd42eb0..b1471c4 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -1,25 +1,23 @@ package subsetter import ( - "context" "reflect" "testing" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) func TestGetRelations(t *testing.T) { conn := getTestConnection() - populateTests(conn) - defer 
conn.Close(context.Background()) - defer clearPopulateTests(conn) + initSchema(conn) + defer clearSchema(conn) tests := []struct { name string table string - conn *pgx.Conn + conn *pgxpool.Pool wantRelations []Relation }{ - {"With relation", "simple", conn, []Relation{{"relation", "simple_id"}}}, + {"With relation", "simple", conn, []Relation{{"simple", "id", "relation", "simple_id"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -29,3 +27,20 @@ func TestGetRelations(t *testing.T) { }) } } + +func TestRelation_Query(t *testing.T) { + tests := []struct { + name string + r Relation + want string + }{ + {"Simple", Relation{"simple", "id", "relation", "simple_id"}, "SELECT * FROM relation WHERE simple_id IN (SELECT id FROM simple)"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.Query(); got != tt.want { + t.Errorf("Relation.Query() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/subsetter/sync.go b/subsetter/sync.go index f8dbdd4..1eb95d8 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -3,24 +3,40 @@ package subsetter import ( "context" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog/log" "github.com/samber/lo" ) type Sync struct { - source *pgx.Conn - destination *pgx.Conn + source *pgxpool.Pool + destination *pgxpool.Pool fraction float64 verbose bool + force []Force } -func NewSync(source string, target string, fraction float64, verbose bool) (*Sync, error) { - src, err := pgx.Connect(context.Background(), source) +type Force struct { + Table string + Where string +} + +func NewSync(source string, target string, fraction float64, force []Force, verbose bool) (*Sync, error) { + src, err := pgxpool.New(context.Background(), source) + if err != nil { + return nil, err + } + + err = src.Ping(context.Background()) if err != nil { return nil, err } - dst, err := pgx.Connect(context.Background(), source) + + dst, err := 
pgxpool.New(context.Background(), target) + if err != nil { + return nil, err + } + err = dst.Ping(context.Background()) if err != nil { return nil, err } @@ -30,24 +46,37 @@ func NewSync(source string, target string, fraction float64, verbose bool) (*Syn destination: dst, fraction: fraction, verbose: verbose, + force: force, }, nil } -func (s *Sync) Sync() (err error) { - var tables []Table - if tables, err = GetTablesWithRows(s.source); err != nil { +// Close closes the connections to the source and destination databases +func (s *Sync) Close() { + s.source.Close() + s.destination.Close() +} + +// copyTableData copies the data from a table in the source database to the destination database +func copyTableData(table Table, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { + var data string + if data, err = CopyTableToString(table.Name, table.Rows, source); err != nil { return } - - var subset []Table - if subset = GetTargetSet(s.fraction, tables); err != nil { + if err = CopyStringToTable(table.Name, data, destination); err != nil { return } + return +} + +// ViableSubset returns a subset of tables that can be copied to the destination database +func ViableSubset(tables []Table) (subset []Table) { + // Filter out tables with no rows - subset = lo.Filter(subset, func(table Table, _ int) bool { return table.Rows > 0 }) + subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) - // Generate a list of relations that should be excluded from the subset - relations := lo.FlatMap(subset, func(table Table, _ int) []string { return table.Relations }) + // Get all relations + relationsR := lo.FlatMap(subset, func(table Table, _ int) []Relation { return table.Relations }) + relations := lo.Map(relationsR, func(relation Relation, _ int) string { return relation.ForeignTable }) // Filter out tables that are relations of other tables // they will be copied later @@ -55,22 +84,71 @@ func (s *Sync) Sync() (err error) { return 
!lo.Contains(relations, table.Name) }) - for _, table := range subset { - var data string - if data, err = CopyTableToString(table.Name, table.Rows, s.source); err != nil { + return +} + +// CopyTables copies each listed table from the source database to the destination database: its sampled rows, then rows matching any forced condition for that table, then the rows of tables that reference it +func (s *Sync) CopyTables(tables []Table) (err error) { + for _, table := range tables { + log.Info().Msgf("Copying table %s", table.Name) + if err = copyTableData(table, s.source, s.destination); err != nil { return } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return + + for _, force := range s.force { + if force.Table == table.Name { + log.Info().Msgf("Selecting forced rows for table %s", table.Name) + var data string + if data, err = CopyQueryToString("SELECT * FROM "+force.Table+" WHERE "+force.Where, s.source); err != nil { + return + } + if err = CopyStringToTable(table.Name, data, s.destination); err != nil { + return + } + } } - } - //copy relations, TODO: make this work - for _, table := range subset { for _, relation := range table.Relations { - log.Debug().Msgf("Copying relation %s for table %s", relation, table.Name) + // Rows selected by relation.Query() belong to the referencing table, so load them into relation.ForeignTable + + log.Info().Msgf("Copying relation %s for table %s", relation, table.Name) + var data string + if data, err = CopyQueryToString(relation.Query(), s.source); err != nil { + return + } + if err = CopyStringToTable(relation.ForeignTable, data, s.destination); err != nil { + return + } } } + return +} + +// Sync copies a subset of tables from source to destination +func (s *Sync) Sync() (err error) { + var tables []Table + if tables, err = GetTablesWithRows(s.source); err != nil { + return + } + + var allTables []Table + if allTables = GetTargetSet(s.fraction, tables); err != nil { + return + } + + subset := ViableSubset(allTables) + + if s.verbose { + for _, t := range subset { + log.Info().Msgf("Copying table %s with %d rows", t.Name, t.Rows) + log.Info().Msgf("Relations: %v", t.Relations) + + } + } + + 
if err = s.CopyTables(subset); err != nil { + return + } return } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go new file mode 100644 index 0000000..67e85a3 --- /dev/null +++ b/subsetter/sync_test.go @@ -0,0 +1,64 @@ +package subsetter + +import ( + "os" + "reflect" + "testing" +) + +func skipCI(t *testing.T) { + if os.Getenv("CI") != "" { + t.Skip("Skipping testing in CI environment") + } +} + +func TestViableSubset(t *testing.T) { + tests := []struct { + name string + tables []Table + wantSubset []Table + }{ + { + "Simple", + []Table{{"simple", 10, []Relation{}}}, + []Table{{"simple", 10, []Relation{}}}, + }, + { + "No rows", + []Table{{"simple", 0, []Relation{}}}, + []Table{}}, + { + "Complex, related tables must be excluded", + []Table{{"simple", 10, []Relation{}}, {"complex", 10, []Relation{{"simple", "id", "complex", "simple_id"}}}}, + []Table{{"simple", 10, []Relation{}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotSubset := ViableSubset(tt.tables); !reflect.DeepEqual(gotSubset, tt.wantSubset) { + t.Errorf("ViableSubset() = %v, want %v", gotSubset, tt.wantSubset) + } + }) + } +} + +func TestSync_CopyTables(t *testing.T) { + skipCI(t) + src := getTestConnection() + dst := getTestConnectionDst() + initSchema(src) + initSchema(dst) + defer clearSchema(src) + defer clearSchema(dst) + + s := &Sync{ + source: src, + destination: dst, + } + tables := []Table{{"simple", 10, []Relation{}}} + + if err := s.CopyTables(tables); err != nil { + t.Errorf("Sync.CopyTables() error = %v", err) + } + +}