From 8bb192e9602c154406d74eb1c613716d3dfe1090 Mon Sep 17 00:00:00 2001
From: "Michael S. Molina"
Date: Tue, 6 Aug 2024 11:33:44 -0400
Subject: [PATCH] Handles SQLite

---
 superset/migrations/shared/catalogs.py | 86 +++++++++++++++++++-------
 1 file changed, 62 insertions(+), 24 deletions(-)

diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index 28a944707e776..2235f6228a45c 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -97,8 +97,6 @@ class Slice(Base):
 
 ModelType = Union[Type[Query], Type[SavedQuery], Type[TabState], Type[TableSchema]]
 
-BATCH_SIZE = 10000
-
 MODELS: list[tuple[ModelType, str]] = [
     (Query, "database_id"),
     (SavedQuery, "db_id"),
@@ -124,8 +122,17 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
     return sorted({name[0][1:-1].split("].[")[-1] for name in names})
 
 
+def get_batch_size(session: Session) -> int:
+    max_sqlite_in = 999
+    return max_sqlite_in if session.bind.dialect.name == "sqlite" else 10000
+
+
 def print_processed_batch(
-    start_time: datetime, offset: int, total_rows: int, model: ModelType
+    start_time: datetime,
+    offset: int,
+    total_rows: int,
+    model: ModelType,
+    batch_size: int,
 ) -> None:
     """
     Print the progress of batch processing.
@@ -139,11 +146,12 @@ def print_processed_batch(
         offset (int): The current offset in the batch processing.
         total_rows (int): The total number of rows to process.
         model (ModelType): The model being processed.
+        batch_size (int): The size of the batch being processed.
     """
     elapsed_time = datetime.now() - start_time
     elapsed_seconds = elapsed_time.total_seconds()
     elapsed_formatted = f"{int(elapsed_seconds // 3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds % 60):02}"
-    rows_processed = min(offset + BATCH_SIZE, total_rows)
+    rows_processed = min(offset + batch_size, total_rows)
     logger.info(
         f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} {model.__tablename__} rows processed "
         f"({(rows_processed / total_rows) * 100:.2f}%)"
     )
@@ -158,7 +166,7 @@ def update_catalog_column(
 
     This function iterates over a list of models defined by MODELS and updates the
-    `catalog` columnto the specified catalog or None depending on the downgrade
-    parameter. The update is performedin batches to optimize performance and reduce
+    `catalog` column to the specified catalog or None depending on the downgrade
+    parameter. The update is performed in batches to optimize performance and reduce
     memory usage.
 
     Parameters:
@@ -186,26 +194,41 @@ def update_catalog_column(
             f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
         )
 
+        batch_size = get_batch_size(session)
+
         # Update in batches using row numbers
-        for i in range(0, total_rows, BATCH_SIZE):
+        for i in range(0, total_rows, batch_size):
             subquery = (
                 session.query(model.id)
                 .filter(getattr(model, column) == database.id)
                 .filter(model.catalog == catalog if downgrade else True)
                 .order_by(model.id)
                 .offset(i)
-                .limit(BATCH_SIZE)
+                .limit(batch_size)
                 .subquery()
             )
-            session.execute(
-                sa.update(model)
-                .where(model.id == subquery.c.id)
-                .values(catalog=None if downgrade else catalog)
-                .execution_options(synchronize_session=False)
-            )
+
+            # SQLite does not support multiple-table criteria within UPDATE
+            if session.bind.dialect.name == "sqlite":
+                ids_to_update = [row.id for row in session.query(subquery.c.id).all()]
+                if ids_to_update:
+                    session.execute(
+                        sa.update(model)
+                        .where(model.id.in_(ids_to_update))
+                        .values(catalog=None if downgrade else catalog)
+                        .execution_options(synchronize_session=False)
+                    )
+            else:
+                session.execute(
+                    sa.update(model)
+                    .where(model.id == subquery.c.id)
+                    .values(catalog=None if downgrade else catalog)
+                    .execution_options(synchronize_session=False)
+                )
+
             # Commit the transaction after each batch
             session.commit()
-            print_processed_batch(start_time, i, total_rows, model)
+            print_processed_batch(start_time, i, total_rows, model, batch_size)
 
 
 def update_schema_catalog_perms(
@@ -257,8 +280,10 @@ def update_schema_catalog_perms(
         .filter(Database.id == database.id)
         .filter(Slice.datasource_type == "table")
     ):
-        chart.catalog_perm = catalog_perm
-        chart.schema_perm = mapping[chart.datasource_id]
+        # We only care about tables that exist in the mapping
+        if mapping.get(chart.datasource_id) is not None:
+            chart.catalog_perm = catalog_perm
+            chart.schema_perm = mapping[chart.datasource_id]
 
 
 def delete_models_non_default_catalog(
@@ -292,25 +317,38 @@ def delete_models_non_default_catalog(
             f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
        )
 
+        batch_size = get_batch_size(session)
+
         # Update in batches using row numbers
-        for i in range(0, total_rows, BATCH_SIZE):
+        for i in range(0, total_rows, batch_size):
             subquery = (
                 session.query(model.id)
                 .filter(getattr(model, column) == database.id)
                 .filter(model.catalog != catalog)
                 .order_by(model.id)
                 .offset(i)
-                .limit(BATCH_SIZE)
+                .limit(batch_size)
                 .subquery()
            )
-            session.execute(
-                sa.delete(model)
-                .where(model.id == subquery.c.id)
-                .execution_options(synchronize_session=False)
-            )
+
+            # SQLite does not support multiple-table criteria within DELETE
+            if session.bind.dialect.name == "sqlite":
+                ids_to_delete = [row.id for row in session.query(subquery.c.id).all()]
+                if ids_to_delete:
+                    session.execute(
+                        sa.delete(model)
+                        .where(model.id.in_(ids_to_delete))
+                        .execution_options(synchronize_session=False)
+                    )
+            else:
+                session.execute(
+                    sa.delete(model)
+                    .where(model.id == subquery.c.id)
+                    .execution_options(synchronize_session=False)
+                )
 
             # Commit the transaction after each batch
             session.commit()
-            print_processed_batch(start_time, i, total_rows, model)
+            print_processed_batch(start_time, i, total_rows, model, batch_size)
 
 
 def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
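
Note: both `sqlite` branches above implement the same workaround. As the inline comments say, SQLite does not support multiple-table criteria within UPDATE or DELETE, so the subquery-join form used on other dialects is replaced by materializing each batch of ids and filtering with `IN`. The 999 returned by `get_batch_size` tracks SQLite's historical default bound-parameter limit (SQLITE_MAX_VARIABLE_NUMBER), which caps how many ids fit in one `IN` list. Below is a minimal standalone sketch of that select-then-IN pattern against an in-memory SQLite database; the `queries` table is illustrative, not Superset's model:

    import sqlalchemy as sa
    from sqlalchemy.orm import Session

    engine = sa.create_engine("sqlite://")
    metadata = sa.MetaData()

    # Illustrative stand-in for one of the MODELS tables.
    queries = sa.Table(
        "queries",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("catalog", sa.String, nullable=True),
    )
    metadata.create_all(engine)

    with Session(engine) as session:
        session.execute(sa.insert(queries), [{"catalog": None}] * 2500)
        session.commit()

        # Older SQLite builds cap bound parameters at 999
        # (SQLITE_MAX_VARIABLE_NUMBER), so keep each IN list at or below that.
        batch_size = 999 if session.bind.dialect.name == "sqlite" else 10000

        total_rows = session.execute(
            sa.select(sa.func.count()).select_from(queries)
        ).scalar_one()

        for i in range(0, total_rows, batch_size):
            # Step 1: collect this batch's ids with a plain SELECT, since
            # SQLite's UPDATE cannot join against the subquery directly.
            ids = session.execute(
                sa.select(queries.c.id)
                .order_by(queries.c.id)
                .offset(i)
                .limit(batch_size)
            ).scalars().all()

            if ids:
                # Step 2: update by primary key, one bound parameter per id.
                session.execute(
                    sa.update(queries)
                    .where(queries.c.id.in_(ids))
                    .values(catalog="main")
                )
            session.commit()

The DELETE branch has the same shape, with `sa.delete(queries)` in step 2.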