Skip to content

Commit

Permalink
Skip generated columns in copy_rows method
Browse files: browse the repository at this point in the history
  • Loading branch information
Ceda committed Feb 22, 2023
1 parent 4dafc49 commit 7c33238
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions psql_database_helper.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,6 +20,9 @@ def turn_off_constraints(connection):
def copy_rows(source, destination, query, destination_table):
datatypes = get_table_datatypes(table_name(destination_table), schema_name(destination_table), destination)

non_generated_columns = [(dt[0], dt[1]) for i, dt in enumerate(datatypes) if dt[2] != 's']
generated_columns_positions = [i for i, dt in enumerate(datatypes) if 's' in dt[2]]

def template_piece(dt):
if dt == '_json':
return '%s::json[]'
Expand All @@ -28,8 +31,8 @@ def template_piece(dt):
else:
return '%s'

template = '(' + ','.join([template_piece(dt) for dt in datatypes]) + ')'

template = '(' + ','.join([template_piece(dt[1]) for dt in non_generated_columns]) + ')'
columns = '(' + ','.join([dt[0] for dt in non_generated_columns]) + ')'

cursor_name='table_cursor_'+str(uuid.uuid4()).replace('-','')
cursor = source.cursor(name=cursor_name)
Expand All @@ -41,13 +44,14 @@ def template_piece(dt):
if len(rows) == 0:
break

# we end up doing a lot of execute statements here, copying data.
# using the inner_cursor means we don't log all the noise
destination_cursor = destination.cursor().inner_cursor

insert_query = 'INSERT INTO {} VALUES %s'.format(fully_qualified_table(destination_table))
insert_query = 'INSERT INTO {} {} VALUES %s'.format(fully_qualified_table(destination_table), columns)

updated_rows = [tuple(val for i, val in enumerate(row) if i not in generated_columns_positions) for row in rows]

execute_values(destination_cursor, insert_query, rows, template)
execute_values(destination_cursor, insert_query, updated_rows, template)

destination_cursor.close()

Expand Down Expand Up @@ -186,7 +190,7 @@ def get_table_datatypes(table, schema, conn):
else:
table_clause = "cl.relname = '{}' AND ns.nspname = '{}'".format(table, schema)
with conn.cursor() as cur:
cur.execute("""SELECT ty.typname
cur.execute("""SELECT att.attname, ty.typname, att.attgenerated
FROM pg_attribute att
JOIN pg_class cl ON cl.oid = att.attrelid
JOIN pg_type ty ON ty.oid = att.atttypid
Expand All @@ -196,7 +200,8 @@ def get_table_datatypes(table, schema, conn):
ORDER BY att.attnum;
""".format(table_clause))

return [r[0] for r in cur.fetchall()]
return [(r[0], r[1], r[2]) for r in cur.fetchall()]


def truncate_table(target_table, conn):
with conn.cursor() as cur:
Expand Down

0 comments on commit 7c33238

Please sign in to comment.