diff --git a/Makefile b/Makefile index c237dd6..28c23cb 100644 --- a/Makefile +++ b/Makefile @@ -13,8 +13,12 @@ is-postgres-running: @(pg_isready -h localhost) || (echo "# ==> Startis postgres by running 'make up'" && exit 2) .PHONY: pgweb -pgweb:is-postgres-running - @pgweb --url "postgres://test_target@localhost:5432/test_target?sslmode=disable" +pgweb: is-postgres-running + if [ -n "$(filter-out $@,$(MAKECMDGOALS))" ]; then \ + pgweb --url "$(filter-out $@,$(MAKECMDGOALS))"; \ + else \ + pgweb --url "postgres://test_target@localhost:5432/test_target?sslmode=disable"; \ + fi; build: rm -rf dist @@ -24,7 +28,7 @@ lint: golangci-lint run dump: - pg_dump --no-acl --schema-only -n public -x -O -f ./dump.sql $(filter-out $@,$(MAKECMDGOALS)) + pg_dump --no-acl --schema-only -n public -x -O -c -f ./dump.sql $(filter-out $@,$(MAKECMDGOALS)) restore: psql -f ./dump.sql "postgres://test_target@localhost:5432/test_target?sslmode=disable" diff --git a/cli/extra.go b/cli/extra.go index 113efed..e217000 100644 --- a/cli/extra.go +++ b/cli/extra.go @@ -25,7 +25,7 @@ func (i *arrayExtra) Set(value string) error { func maybeAll(s string) string { if s == "all" { - return "1=1" + return subsetter.RuleAll } return s } diff --git a/devenv.nix b/devenv.nix index 1feb94a..2afa6f0 100644 --- a/devenv.nix +++ b/devenv.nix @@ -14,6 +14,7 @@ eclint # EditorConfig linter and fixer gnumake # GNU Make goreleaser # Go binary release tool + pgweb # PostgreSQL web interface ]; languages.javascript.enable = true; diff --git a/subsetter/copy.go b/subsetter/copy.go new file mode 100644 index 0000000..b8eab2b --- /dev/null +++ b/subsetter/copy.go @@ -0,0 +1,141 @@ +package subsetter + +import ( + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" + "github.com/samber/lo" +) + +// copyTableData copies the data from a table in the source database to the destination database +func copyTableData(table Table, relatedQueries []string, withLimit bool, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { + // Backtrace the inserted ids from main table to related table + subSelectQuery := "" + if len(relatedQueries) > 0 { + subSelectQuery = "WHERE " + strings.Join(relatedQueries, " AND ") + } + + limit := "" + if withLimit { + limit = fmt.Sprintf("LIMIT %d", table.Rows) + } + + var data string + if data, err = CopyTableToString(table.Name, limit, subSelectQuery, source); err != nil { + //log.Error().Err(err).Str("table", table.Name).Msg("Error getting table data") + return + } + if err = CopyStringToTable(table.Name, data, destination); err != nil { + //log.Error().Err(err).Str("table", table.Name).Msg("Error pushing table data") + return + } + return + +} + +func relatedQueriesBuilder( + depth *int, + tables []Table, + relation Relation, + table Table, + source *pgxpool.Pool, + destination *pgxpool.Pool, + visitedTables *[]string, + relatedQueries *[]string, +) (err error) { + +retry: + q := fmt.Sprintf(`SELECT %s FROM %s`, relation.ForeignColumn, relation.ForeignTable) + log.Debug().Str("query", q).Msgf("Getting keys for %s from target", table.Name) + + if primaryKeys, err := GetKeys(q, destination); err != nil { + log.Error().Err(err).Msgf("Error getting keys for %s", table.Name) + return err + } else { + if len(primaryKeys) == 0 { + + missingTable := TableByName(tables, relation.ForeignTable) + if err = relationalCopy(depth, tables, missingTable, visitedTables, source, destination); err != nil { + return errors.Wrapf(err, 
"Error copying table %s", missingTable.Name) + } + + // Retry short circuit + *depth++ + + log.Debug().Int("depth", *depth).Msgf("Retrying keys for %s", relation.ForeignTable) + if *depth < 1 { + goto retry + } else { + log.Debug().Str("table", relation.ForeignTable).Str("primary", relation.PrimaryTable).Msgf("No keys found at this time") + return errors.New("Max depth reached") + } + + } else { + *depth = 0 + keys := lo.Map(primaryKeys, func(key string, _ int) string { + return QuoteString(key) + }) + rq := fmt.Sprintf(`%s IN (%s)`, relation.PrimaryColumn, strings.Join(keys, ",")) + *relatedQueries = append(*relatedQueries, rq) + } + } + return nil +} + +func relationalCopy( + depth *int, + tables []Table, + table Table, + visitedTables *[]string, + source *pgxpool.Pool, + destination *pgxpool.Pool, +) error { + log.Debug().Str("table", table.Name).Msg("Preparing") + + relatedTables, err := TableGraph(table.Name, table.Relations) + if err != nil { + return errors.Wrapf(err, "Error sorting tables from graph") + } + log.Debug().Strs("tables", relatedTables).Msgf("Order of copy") + + for _, tableName := range relatedTables { + + if lo.Contains(*visitedTables, tableName) { + continue + } + + relatedTable := TableByName(tables, tableName) + *visitedTables = append(*visitedTables, relatedTable.Name) + // Use realized query to get primary keys that are already in the destination for all related tables + + // Selection query for this table + relatedQueries := []string{} + + for _, relation := range relatedTable.Relations { + err := relatedQueriesBuilder(depth, tables, relation, relatedTable, source, destination, visitedTables, &relatedQueries) + if err != nil { + return err + } + } + + if len(relatedQueries) > 0 { + log.Debug().Str("table", relatedTable.Name).Strs("relatedQueries", relatedQueries).Msg("Transferring with relationalCopy") + } + + if err = copyTableData(relatedTable, relatedQueries, false, source, destination); err != nil { + if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { // foreign key violation + if err := relationalCopy(depth, tables, relatedTable, visitedTables, source, destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) + } + } + return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) + } + + } + + return nil +} diff --git a/subsetter/graph.go b/subsetter/graph.go index ceba4bf..a04157a 100644 --- a/subsetter/graph.go +++ b/subsetter/graph.go @@ -24,3 +24,22 @@ func TableGraph(primary string, relations []Relation) (l []string, err error) { slices.Reverse(l) return } + +func RequiredTableGraph(primary string, relations []Relation) (l []string, err error) { + graph := topsort.NewGraph() // Create a new graph + + for _, r := range relations { + if !r.IsSelfRelated() { + err = graph.AddEdge(r.ForeignTable, r.PrimaryTable) + if err != nil { + return + } + } + } + l, err = graph.TopSort(primary) + if err != nil { + return + } + slices.Reverse(l) + return +} diff --git a/subsetter/info.go b/subsetter/info.go index 2517398..de37a22 100644 --- a/subsetter/info.go +++ b/subsetter/info.go @@ -4,21 +4,16 @@ import ( "fmt" "math" "strconv" + + "github.com/samber/lo" ) // GetTargetSet returns a subset of tables with the number of rows scaled by the fraction. 
func GetTargetSet(fraction float64, tables []Table) []Table { - var subset []Table - - for _, table := range tables { - subset = append(subset, Table{ - Name: table.Name, - Rows: int(math.Pow(10, math.Log10(float64(table.Rows))*fraction)), - Relations: table.Relations, - }) - } - - return subset + return lo.Map(tables, func(table Table, i int) Table { + table.Rows = int(math.Pow(10, math.Log10(float64(table.Rows))*fraction)) + return table + }) } func QuoteString(s string) string { diff --git a/subsetter/info_test.go b/subsetter/info_test.go index edc88de..c886b58 100644 --- a/subsetter/info_test.go +++ b/subsetter/info_test.go @@ -15,8 +15,12 @@ func TestGetTargetSet(t *testing.T) { tables []Table want []Table }{ - {"simple", 0.5, []Table{{"simple", 1000, []Relation{}}}, []Table{{"simple", 31, []Relation{}}}}, - {"simple", 0.5, []Table{{"simple", 10, []Relation{}}}, []Table{{"simple", 3, []Relation{}}}}, + {"simple", 0.5, + []Table{{"simple", 1000, []Relation{}, []Relation{}}}, + []Table{{"simple", 31, []Relation{}, []Relation{}}}}, + {"simple", 0.5, + []Table{{"simple", 10, []Relation{}, []Relation{}}}, + []Table{{"simple", 3, []Relation{}, []Relation{}}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/query.go b/subsetter/query.go index 6507620..4d91d65 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -12,9 +12,10 @@ import ( ) type Table struct { - Name string - Rows int - Relations []Relation + Name string + Rows int + Relations []Relation + RequiredBy []Relation } // RelationNames returns a list of relation names in human readable format. @@ -38,13 +39,12 @@ func (t *Table) IsSelfRelated() bool { // IsSelfRelated returns true if a table is self related. func TableByName(tables []Table, name string) Table { - return lo.Filter(tables, func(table Table, _ int) bool { - return table.Name == name - })[0] + return lo.FindOrElse(tables, Table{}, func(t Table) bool { + return t.Name == name + }) } // GetTablesWithRows returns a list of tables with the number of rows in each table. -// Warning reltuples used to dermine size is an estimate of the number of rows in the table and can be zero for small tables. func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { q := `SELECT relname, @@ -79,10 +79,10 @@ func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { } // Get relations - table.Relations, err = GetRelations(table.Name, conn) - if err != nil { - return nil, err - } + table.Relations = GetRelations(table.Name, conn) + + // Get reverse relations + table.RequiredBy = GetRequiredBy(table.Name, conn) tables = append(tables, table) } @@ -153,13 +153,19 @@ func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err err // CopyTableToString copies a table to a string. func CopyTableToString(table string, limit string, where string, conn *pgxpool.Pool) (result string, err error) { - q := fmt.Sprintf(`SELECT * FROM %s %s order by random() %s`, table, where, limit) - log.Debug().Msgf("Query: %s", q) + maybeOrder := "" + if lo.IsNotEmpty(where) { + maybeOrder = "order by random()" + } + + q := fmt.Sprintf(`SELECT * FROM %s %s %s %s`, table, where, maybeOrder, limit) + log.Debug().Msgf("CopyTableToString query: %s", q) return CopyQueryToString(q, conn) } // CopyStringToTable copies a string to a table. 
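// The data argument is fed verbatim to `copy <table> from stdin`, so it is
// expected to be in the COPY text format that CopyTableToString /
// CopyQueryToString produce.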
func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error) { + log.Debug().Msgf("CopyStringToTable query: %s", table) q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer buff.WriteString(data) diff --git a/subsetter/query_test.go b/subsetter/query_test.go index aff3aea..2b6d086 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -17,7 +17,11 @@ func TestGetTablesWithRows(t *testing.T) { wantTables []Table wantErr bool }{ - {"With tables", conn, []Table{{"simple", 0, []Relation{}}, {"relation", 0, []Relation{}}}, false}, + {"With tables", conn, + []Table{ + {"simple", 0, []Relation{}, []Relation{}}, + {"relation", 0, []Relation{}, []Relation{}}, + }, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/subsetter/relations.go b/subsetter/relations.go index 5f69019..7a48239 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -5,11 +5,16 @@ import ( "fmt" "regexp" "strings" + "sync" "github.com/jackc/pgx/v5/pgxpool" + "github.com/rs/zerolog/log" "github.com/samber/lo" ) +var cachedRelations *[]RelationRaw +var mutexCachedRelations sync.Once + type Relation struct { PrimaryTable string PrimaryColumn string @@ -55,12 +60,12 @@ func (r *RelationRaw) toRelation() Relation { return rel } -// GetRelations returns a list of tables that have a foreign key for particular table. -func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { +func getAllRelations(table string, conn *pgxpool.Pool) *[]RelationRaw { - q := `SELECT + mutexCachedRelations.Do(func() { + q := `SELECT conrelid::regclass AS primary_table, - confrelid::regclass AS refrerenced_table, + confrelid::regclass AS referenced_table, pg_get_constraintdef(c.oid, TRUE) AS sql FROM pg_constraint c @@ -69,24 +74,45 @@ func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err e c.contype = 'f' AND n.nspname = 'public';` - rows, err := conn.Query(context.Background(), q) - if err != nil { - return - } - defer rows.Close() - - for rows.Next() { - var rel RelationRaw - - err = rows.Scan(&rel.PrimaryTable, &rel.ForeignTable, &rel.SQL) + rows, err := conn.Query(context.Background(), q) if err != nil { return } + defer rows.Close() + relations := []RelationRaw{} + for rows.Next() { + var rel RelationRaw + + err = rows.Scan(&rel.PrimaryTable, &rel.ForeignTable, &rel.SQL) + if err != nil { + return + } + relations = append(relations, rel) + log.Debug().Str("table", rel.PrimaryTable).Str("foreign", rel.ForeignTable).Msg("Found relation") + } + cachedRelations = &relations + + }) + + return cachedRelations +} + +// GetRelations returns a list of tables that are foreign key for particular table. +func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation) { + for _, rel := range *getAllRelations(table, conn) { if table == rel.PrimaryTable { relations = append(relations, rel.toRelation()) } - } + return +} +// GetRequiredBy returns a list of tables that have are foreign key for particular table. 
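// In other words, GetRequiredBy is the inverse of GetRelations: instead of the
// foreign keys declared on the given table, it returns the foreign keys on
// other tables that reference it (i.e. the tables that depend on its rows).
// Both read from the relation cache that getAllRelations populates once via
// sync.Once.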
+func GetRequiredBy(table string, conn *pgxpool.Pool) (relations []Relation) { + for _, rel := range *getAllRelations(table, conn) { + if table == rel.ForeignTable { + relations = append(relations, rel.toRelation()) + } + } return } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index 9eb01e6..30373d1 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -22,7 +22,7 @@ func TestGetRelations(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if gotRelations, _ := GetRelations(tt.table, tt.conn); !reflect.DeepEqual(gotRelations, tt.wantRelations) { + if gotRelations := GetRelations(tt.table, tt.conn); !reflect.DeepEqual(gotRelations, tt.wantRelations) { t.Errorf("GetRelations() = %v, want %v", gotRelations, tt.wantRelations) } }) diff --git a/subsetter/rule.go b/subsetter/rule.go new file mode 100644 index 0000000..3a7b627 --- /dev/null +++ b/subsetter/rule.go @@ -0,0 +1,120 @@ +package subsetter + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" + "github.com/rs/zerolog/log" + "github.com/samber/lo" +) + +const RuleAll = "1=1" + +type Rule struct { + Table string + Where string +} + +func (r *Rule) String() string { + return fmt.Sprintf("%s:%s", r.Table, r.Where) +} + +func (r *Rule) Query(exclude []string) string { + if r.Where == "" { + return fmt.Sprintf("SELECT * FROM %s", r.Table) + } + + if len(exclude) > 0 { + exclude = lo.Map(exclude, func(s string, _ int) string { + return QuoteString(s) + }) + r.Where = fmt.Sprintf("%s AND id NOT IN (%s)", r.Where, strings.Join(exclude, ",")) + } + return fmt.Sprintf("SELECT * FROM %s WHERE %s", r.Table, r.Where) +} + +func GetPrimaryKeyNameRel(t Table, relatedTable string) string { + log.Debug().Str("table", t.Name).Str("relatedTable", relatedTable).Msg("Getting primary key name for related table") + for _, r := range t.RequiredBy { + if r.ForeignTable == relatedTable { + return r.PrimaryColumn + } + } + for _, r := range t.Relations { + if r.ForeignTable == relatedTable { + return r.PrimaryColumn + } + } + panic(fmt.Sprintf("No primary key found for table %s", t.Name)) +} + +func (r *Rule) QueryInclude(include []string, relatedTable Table) string { + q := fmt.Sprintf("SELECT * FROM %s", relatedTable.Name) + relatedTableKey := GetPrimaryKeyNameRel(relatedTable, r.Table) + + if len(include) > 0 { + include = lo.Map(include, func(s string, _ int) string { + return QuoteString(s) + }) + q = fmt.Sprintf("%s WHERE %s IN (%s)", q, relatedTableKey, strings.Join(include, ",")) + } + log.Debug().Str("query", q).Msgf("Query for related table %s", relatedTable.Name) + return q +} + +func (r *Rule) Copy(s *Sync) (err error) { + log.Debug().Str("query", r.Where).Msgf("Transferring forced rows for table %s", r.Table) + var data string + + keyName, err := GetPrimaryKeyName(r.Table, s.destination) + if err != nil { + return errors.Wrapf(err, "Error getting primary key for table %s", r.Table) + } + + q := fmt.Sprintf(`SELECT %s FROM %s`, keyName, r.Table) + log.Debug().Str("query", q).Msgf("Getting keys for %s from target", r.Table) + + excludedIDs := []string{} + if primaryKeys, err := GetKeys(q, s.destination); err == nil { + excludedIDs = primaryKeys + } + log.Debug().Strs("excludedIDs", excludedIDs).Msgf("Excluded IDs for table %s", r.Table) + + if data, err = CopyQueryToString(r.Query(excludedIDs), s.source); err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", r.Table) + } + if err = CopyStringToTable(r.Table, data, s.destination); err != 
nil { + return errors.Wrapf(err, "Error inserting forced rows for table %s", r.Table) + } + log.Debug().Str("table", r.Table).Msgf("Transfered rows") + return +} + +func (r *Rule) CopyRelated(s *Sync, relatedTable Table) (err error) { + log.Debug().Str("query", r.Where).Msgf("Transferring forced rows for table %s", r.Table) + var data string + + keyName, err := GetPrimaryKeyName(r.Table, s.destination) + if err != nil { + return errors.Wrapf(err, "Error getting primary key for table %s", r.Table) + } + + q := fmt.Sprintf(`SELECT %s FROM %s WHERE %s`, keyName, r.Table, r.Where) + log.Debug().Str("query", q).Msgf("Getting keys for %s from target", r.Table) + + includedIDs := []string{} + if primaryKeys, err := GetKeys(q, s.source); err == nil { + includedIDs = primaryKeys + } + log.Debug().Strs("includedIDs", includedIDs).Str("table", relatedTable.Name).Msgf("Included IDs for table %s", r.Table) + + if data, err = CopyQueryToString(r.QueryInclude(includedIDs, relatedTable), s.source); err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", r.Table) + } + if err = CopyStringToTable(r.Table, data, s.destination); err != nil { + return errors.Wrapf(err, "Error inserting forced rows for table %s", r.Table) + } + log.Debug().Str("table", relatedTable.Name).Msgf("Transfered related rows") + return +} diff --git a/subsetter/sync.go b/subsetter/sync.go index 716aee2..8d83ff1 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -3,9 +3,7 @@ package subsetter import ( "context" "fmt" - "strings" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -29,56 +27,6 @@ type Sync struct { exclude []Rule } -type Rule struct { - Table string - Where string -} - -func (r *Rule) String() string { - return fmt.Sprintf("%s:%s", r.Table, r.Where) -} - -func (r *Rule) Query(exclude []string) string { - if r.Where == "" { - return fmt.Sprintf("SELECT * FROM %s", r.Table) - } - - if len(exclude) > 0 { - exclude = lo.Map(exclude, func(s string, _ int) string { - return QuoteString(s) - }) - r.Where = fmt.Sprintf("%s AND id NOT IN (%s)", r.Where, strings.Join(exclude, ",")) - } - return fmt.Sprintf("SELECT * FROM %s WHERE %s", r.Table, r.Where) -} - -func (r *Rule) Copy(s *Sync) (err error) { - log.Debug().Str("query", r.Where).Msgf("Transferring forced rows for table %s", r.Table) - var data string - - keyName, err := GetPrimaryKeyName(r.Table, s.destination) - if err != nil { - return errors.Wrapf(err, "Error getting primary key for table %s", r.Table) - } - - q := fmt.Sprintf(`SELECT %s FROM %s`, keyName, r.Table) - log.Debug().Str("query", q).Msgf("Getting keys for %s from target", r.Table) - - excludedIDs := []string{} - if primaryKeys, err := GetKeys(q, s.destination); err == nil { - excludedIDs = primaryKeys - } - log.Debug().Strs("excludedIDs", excludedIDs).Msgf("Excluded IDs for table %s", r.Table) - - if data, err = CopyQueryToString(r.Query(excludedIDs), s.source); err != nil { - return errors.Wrapf(err, "Error copying forced rows for table %s", r.Table) - } - if err = CopyStringToTable(r.Table, data, s.destination); err != nil { - return errors.Wrapf(err, "Error inserting forced rows for table %s", r.Table) - } - return -} - func NewSync(source string, target string, fraction float64, include []Rule, exclude []Rule, verbose bool) (*Sync, error) { src, err := pgxpool.New(context.Background(), source) if err != nil { @@ -115,153 +63,47 @@ func (s *Sync) Close() { s.destination.Close() } -// copyTableData 
copies the data from a table in the source database to the destination database -func copyTableData(table Table, relatedQueries []string, withLimit bool, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { - // Backtrace the inserted ids from main table to related table - subSelectQuery := "" - if len(relatedQueries) > 0 { - subSelectQuery = "WHERE " + strings.Join(relatedQueries, " AND ") - } - - limit := "" - if withLimit { - limit = fmt.Sprintf("LIMIT %d", table.Rows) - } - - var data string - if data, err = CopyTableToString(table.Name, limit, subSelectQuery, source); err != nil { - //log.Error().Err(err).Str("table", table.Name).Msg("Error getting table data") - return - } - if err = CopyStringToTable(table.Name, data, destination); err != nil { - //log.Error().Err(err).Str("table", table.Name).Msg("Error pushing table data") - return - } - return - -} - -func relatedQueriesBuilder( - depth *int, - tables []Table, - relation Relation, - table Table, - source *pgxpool.Pool, - destination *pgxpool.Pool, - visitedTables *[]string, - relatedQueries *[]string, -) (err error) { - -retry: - q := fmt.Sprintf(`SELECT %s FROM %s`, relation.ForeignColumn, relation.ForeignTable) - log.Debug().Str("query", q).Msgf("Getting keys for %s from target", table.Name) - - if primaryKeys, err := GetKeys(q, destination); err != nil { - log.Error().Err(err).Msgf("Error getting keys for %s", table.Name) - return err - } else { - if len(primaryKeys) == 0 { - - missingTable := TableByName(tables, relation.ForeignTable) - if err = RelationalCopy(depth, tables, missingTable, visitedTables, source, destination); err != nil { - return errors.Wrapf(err, "Error copying table %s", missingTable.Name) - } - - // Retry short circuit - *depth++ - - log.Debug().Int("depth", *depth).Msgf("Retrying keys for %s", relation.ForeignTable) - if *depth < 1 { - goto retry - } else { - log.Debug().Str("table", relation.ForeignTable).Str("primary", relation.PrimaryTable).Msgf("No keys found at this time") - return errors.New("Max depth reached") - } - - } else { - *depth = 0 - keys := lo.Map(primaryKeys, func(key string, _ int) string { - return QuoteString(key) - }) - rq := fmt.Sprintf(`%s IN (%s)`, relation.PrimaryColumn, strings.Join(keys, ",")) - *relatedQueries = append(*relatedQueries, rq) - } - } - return nil -} - -func RelationalCopy( - depth *int, - tables []Table, - table Table, - visitedTables *[]string, - source *pgxpool.Pool, - destination *pgxpool.Pool, -) error { - log.Debug().Str("table", table.Name).Msg("Preparing") - - relatedTables, err := TableGraph(table.Name, table.Relations) - if err != nil { - return errors.Wrapf(err, "Error sorting tables from graph") - } - log.Debug().Strs("tables", relatedTables).Msgf("Order of copy") - - for _, tableName := range relatedTables { - - if lo.Contains(*visitedTables, tableName) { - continue - } - - relatedTable := TableByName(tables, tableName) - *visitedTables = append(*visitedTables, relatedTable.Name) - // Use realized query to get primary keys that are already in the destination for all related tables - - // Selection query for this table - relatedQueries := []string{} - - for _, relation := range relatedTable.Relations { - err := relatedQueriesBuilder(depth, tables, relation, relatedTable, source, destination, visitedTables, &relatedQueries) - if err != nil { - return err - } - } - - if len(relatedQueries) > 0 { - log.Debug().Str("table", relatedTable.Name).Strs("relatedQueries", relatedQueries).Msg("Transferring with RelationalCopy") - } - - if err = 
copyTableData(relatedTable, relatedQueries, false, source, destination); err != nil { - if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { // foreign key violation - if err := RelationalCopy(depth, tables, relatedTable, visitedTables, source, destination); err != nil { - return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) - } - } - return errors.Wrapf(err, "Error copying table %s", relatedTable.Name) - } - - } - - return nil -} - // CopyTables copies the data from a list of tables in the source database to the destination database func (s *Sync) CopyTables(tables []Table) (err error) { + // Filter out tables that are in include list and have custom rule + customRuleTables := lo.Uniq(lo.Map(s.include, func(rule Rule, _ int) string { + return rule.Table + })) + visitedTables := []string{} + // Copy tables without relations first for _, table := range lo.Filter(tables, func(table Table, _ int) bool { return len(table.Relations) == 0 }) { log.Info().Str("table", table.Name).Msg("Transferring") - if err = copyTableData(table, []string{}, true, s.source, s.destination); err != nil { - return errors.Wrapf(err, "Error copying table %s", table.Name) - } - - for _, include := range s.include { - if include.Table == table.Name { - err = include.Copy(s) - if err != nil { - return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) + if !lo.Contains(customRuleTables, table.Name) { + if err = copyTableData(table, []string{}, true, s.source, s.destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", table.Name) + } + } else { + for _, include := range s.include { + if include.Table == table.Name { + err = include.Copy(s) + if err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) + } + if include.Where != RuleAll { + // reverse copy all related rows + requiredTables, _ := RequiredTableGraph(table.Name, table.RequiredBy) + for _, relation := range requiredTables { + if relation == table.Name { // skip self + continue + } + relatedTable := TableByName(tables, relation) + if relatedTable.Name == "" { // skip self + continue + } + include.CopyRelated(s, relatedTable) + } + //panic("not implemented") + } } } } @@ -278,7 +120,7 @@ func (s *Sync) CopyTables(tables []Table) (err error) { return len(table.Relations) > 0 }) { log.Info().Str("table", complexTable.Name).Msg("Transferring") - if err := RelationalCopy(&depth, tables, complexTable, &visitedTables, s.source, s.destination); err != nil { + if err := relationalCopy(&depth, tables, complexTable, &visitedTables, s.source, s.destination); err != nil { log.Info().Str("table", complexTable.Name).Msgf("Transferring failed, retrying later") maybeRetry = append(maybeRetry, complexTable) } @@ -294,7 +136,7 @@ func (s *Sync) CopyTables(tables []Table) (err error) { visitedRetriedTables := []string{} for _, retiredTable := range maybeRetry { log.Info().Str("table", retiredTable.Name).Msg("Transferring") - if err := RelationalCopy(&depth, tables, retiredTable, &visitedRetriedTables, s.source, s.destination); err != nil { + if err := relationalCopy(&depth, tables, retiredTable, &visitedRetriedTables, s.source, s.destination); err != nil { log.Warn().Str("table", retiredTable.Name).Msgf("Transferring failed, try increasing fraction percentage") } } @@ -337,7 +179,7 @@ func (s *Sync) Sync() (err error) { return !lo.Contains(ruleExcludedTables, table.Name) // excluded tables }) - // Calculate fraction to be coped over + // Calculate fraction to be 
copied over if tables = GetTargetSet(s.fraction, tables); err != nil { return } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go index bffdb7b..4f873e5 100644 --- a/subsetter/sync_test.go +++ b/subsetter/sync_test.go @@ -18,7 +18,7 @@ func TestSync_CopyTables(t *testing.T) { source: src, destination: dst, } - tables := []Table{{"simple", 10, []Relation{}}} + tables := []Table{{"simple", 10, []Relation{}, []Relation{}}} if err := s.CopyTables(tables); err != nil { t.Errorf("Sync.CopyTables() error = %v", err)
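Usage sketch (not part of the patch): the snippet below shows how the include rules introduced here could be wired up through NewSync. The import path, connection strings, table names and the orders WHERE fragment are placeholders; the identifiers used (Rule, RuleAll, NewSync, Sync, Close) follow the signatures visible in this diff.

package main

import (
	"log"

	// Hypothetical import path; substitute the module's real path.
	"example.com/pg-subsetter/subsetter"
)

func main() {
	include := []subsetter.Rule{
		// The CLI's maybeAll helper maps the literal "all" to subsetter.RuleAll
		// ("1=1"), i.e. copy every row of this table.
		{Table: "users", Where: subsetter.RuleAll},
		// A plain WHERE fragment copies only the matching rows; because it is not
		// RuleAll, CopyTables also triggers the new reverse copy
		// (RequiredTableGraph + CopyRelated) for tables that reference them.
		{Table: "orders", Where: "created_at > now() - interval '7 days'"},
	}

	s, err := subsetter.NewSync(
		"postgres://app@localhost:5432/app?sslmode=disable",                  // source
		"postgres://test_target@localhost:5432/test_target?sslmode=disable", // destination
		0.5,     // fraction: row counts are scaled to 10^(log10(n)*fraction)
		include, // forced-inclusion rules
		nil,     // exclusion rules
		true,    // verbose logging
	)
	if err != nil {
		log.Fatal(err)
	}
	defer s.Close()

	if err := s.Sync(); err != nil {
		log.Fatal(err)
	}
}

With fraction 0.5, every table not covered by a rule is sampled with the same logarithmic scaling shown in GetTargetSet above.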