Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip generated columns in copy_rows method #41

Merged
merged 2 commits into from
Mar 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions psql_database_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def turn_off_constraints(connection):
def copy_rows(source, destination, query, destination_table):
datatypes = get_table_datatypes(table_name(destination_table), schema_name(destination_table), destination)

non_generated_columns = [(dt[0], dt[1]) for i, dt in enumerate(datatypes) if dt[2] != 's']
generated_columns_positions = [i for i, dt in enumerate(datatypes) if 's' in dt[2]]
always_generated_id = any([dt[3] == 'a' for dt in datatypes])

def template_piece(dt):
if dt == '_json':
return '%s::json[]'
Expand All @@ -28,8 +32,8 @@ def template_piece(dt):
else:
return '%s'

template = '(' + ','.join([template_piece(dt) for dt in datatypes]) + ')'

template = '(' + ','.join([template_piece(dt[1]) for dt in non_generated_columns]) + ')'
columns = '(' + ','.join([dt[0] for dt in non_generated_columns]) + ')'

cursor_name='table_cursor_'+str(uuid.uuid4()).replace('-','')
cursor = source.cursor(name=cursor_name)
Expand All @@ -41,13 +45,16 @@ def template_piece(dt):
if len(rows) == 0:
break

# we end up doing a lot of execute statements here, copying data.
# using the inner_cursor means we don't log all the noise
destination_cursor = destination.cursor().inner_cursor

insert_query = 'INSERT INTO {} VALUES %s'.format(fully_qualified_table(destination_table))
insert_query = 'INSERT INTO {} {} VALUES %s'.format(fully_qualified_table(destination_table), columns)
if (always_generated_id):
insert_query = 'INSERT INTO {} {} OVERRIDING SYSTEM VALUE VALUES %s'.format(fully_qualified_table(destination_table), columns)

execute_values(destination_cursor, insert_query, rows, template)
updated_rows = [tuple(val for i, val in enumerate(row) if i not in generated_columns_positions) for row in rows]

execute_values(destination_cursor, insert_query, updated_rows, template)

destination_cursor.close()

Expand Down Expand Up @@ -186,7 +193,7 @@ def get_table_datatypes(table, schema, conn):
else:
table_clause = "cl.relname = '{}' AND ns.nspname = '{}'".format(table, schema)
with conn.cursor() as cur:
cur.execute("""SELECT ty.typname
cur.execute("""SELECT att.attname, ty.typname, att.attgenerated, att.attidentity
FROM pg_attribute att
JOIN pg_class cl ON cl.oid = att.attrelid
JOIN pg_type ty ON ty.oid = att.atttypid
Expand All @@ -196,7 +203,7 @@ def get_table_datatypes(table, schema, conn):
ORDER BY att.attnum;
""".format(table_clause))

return [r[0] for r in cur.fetchall()]
return [(r[0], r[1], r[2], r[3]) for r in cur.fetchall()]

def truncate_table(target_table, conn):
with conn.cursor() as cur:
Expand Down