From cf32d301bcd113a2c0963c6ce4170424ba2a4d02 Mon Sep 17 00:00:00 2001 From: Tonda Pleskac Date: Wed, 22 Feb 2023 08:44:56 +0100 Subject: [PATCH] Skip generated columns in copy_rows method --- psql_database_helper.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/psql_database_helper.py b/psql_database_helper.py index 360b181..f57572c 100644 --- a/psql_database_helper.py +++ b/psql_database_helper.py @@ -20,6 +20,9 @@ def turn_off_constraints(connection): def copy_rows(source, destination, query, destination_table): datatypes = get_table_datatypes(table_name(destination_table), schema_name(destination_table), destination) + non_generated_columns = [(dt[0], dt[1]) for i, dt in enumerate(datatypes) if dt[2] != 's'] + generated_columns_positions = [i for i, dt in enumerate(datatypes) if 's' in dt[2]] + def template_piece(dt): if dt == '_json': return '%s::json[]' @@ -28,8 +31,8 @@ def template_piece(dt): else: return '%s' - template = '(' + ','.join([template_piece(dt) for dt in datatypes]) + ')' - + template = '(' + ','.join([template_piece(dt[1]) for dt in non_generated_columns]) + ')' + columns = '(' + ','.join([dt[0] for dt in non_generated_columns]) + ')' cursor_name='table_cursor_'+str(uuid.uuid4()).replace('-','') cursor = source.cursor(name=cursor_name) @@ -41,13 +44,14 @@ def template_piece(dt): if len(rows) == 0: break - # we end up doing a lot of execute statements here, copying data. # using the inner_cursor means we don't log all the noise destination_cursor = destination.cursor().inner_cursor - insert_query = 'INSERT INTO {} VALUES %s'.format(fully_qualified_table(destination_table)) + insert_query = 'INSERT INTO {} {} VALUES %s'.format(fully_qualified_table(destination_table), columns) - execute_values(destination_cursor, insert_query, rows, template) + updated_rows = [tuple(val for i, val in enumerate(row) if i not in generated_columns_positions) for row in rows] + + execute_values(destination_cursor, insert_query, updated_rows, template) destination_cursor.close() @@ -186,7 +190,7 @@ def get_table_datatypes(table, schema, conn): else: table_clause = "cl.relname = '{}' AND ns.nspname = '{}'".format(table, schema) with conn.cursor() as cur: - cur.execute("""SELECT ty.typname + cur.execute("""SELECT att.attname, ty.typname, att.attgenerated FROM pg_attribute att JOIN pg_class cl ON cl.oid = att.attrelid JOIN pg_type ty ON ty.oid = att.atttypid @@ -196,7 +200,7 @@ def get_table_datatypes(table, schema, conn): ORDER BY att.attnum; """.format(table_clause)) - return [r[0] for r in cur.fetchall()] + return [(r[0], r[1], r[2]) for r in cur.fetchall()] def truncate_table(target_table, conn): with conn.cursor() as cur: