From 2bd9015ea98bfdde17a91d3989f8c1a7c8a80b86 Mon Sep 17 00:00:00 2001
From: "Michael S. Molina"
Date: Mon, 5 Aug 2024 13:32:55 -0400
Subject: [PATCH] fix: upgrade_catalog_perms implementation

---
 superset/migrations/shared/catalogs.py | 185 ++++++++++++++++++++-----
 1 file changed, 153 insertions(+), 32 deletions(-)

diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
index b09c71739f8b7..bd2fd4af31c21 100644
--- a/superset/migrations/shared/catalogs.py
+++ b/superset/migrations/shared/catalogs.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import logging
+from datetime import datetime
 from typing import Any, Type
 
 import sqlalchemy as sa
@@ -35,8 +36,7 @@
 )
 from superset.models.core import Database
 
-logger = logging.getLogger(__name__)
-
+logger = logging.getLogger("alembic")
 
 Base: Type[Any] = declarative_base()
@@ -95,6 +95,18 @@ class Slice(Base):
     schema_perm = sa.Column(sa.String(1000))
 
 
+ModelType = Type[Query] | Type[SavedQuery] | Type[TabState] | Type[TableSchema]
+
+BATCH_SIZE = 10000
+
+MODELS: list[tuple[ModelType, str]] = [
+    (Query, "database_id"),
+    (SavedQuery, "db_id"),
+    (TabState, "database_id"),
+    (TableSchema, "database_id"),
+]
+
+
 def get_known_schemas(database_name: str, session: Session) -> list[str]:
     """
     Read all known schemas from the existing schema permissions.
@@ -112,6 +124,142 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
     return sorted({name[0][1:-1].split("].[")[-1] for name in names})
 
 
+def print_processed_batch(
+    start_time: datetime, offset: int, total_rows: int, model: ModelType
+) -> None:
+    """
+    Print the progress of batch processing.
+
+    This function logs the progress of processing a batch of rows from a model.
+    It calculates the elapsed time since the start of the batch processing and
+    logs the number of rows processed along with the percentage completion.
+
+    Parameters:
+        start_time (datetime): The start time of the batch processing.
+        offset (int): The current offset in the batch processing.
+        total_rows (int): The total number of rows to process.
+        model (ModelType): The model being processed.
+    """
+    elapsed_time = datetime.now() - start_time
+    elapsed_seconds = elapsed_time.total_seconds()
+    elapsed_formatted = f"{int(elapsed_seconds // 3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds % 60):02}"
+    rows_processed = min(offset + BATCH_SIZE, total_rows)
+    logger.info(
+        f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} {model.__tablename__} rows processed "
+        f"({(rows_processed / total_rows) * 100:.2f}%)"
+    )
+
+
+def update_catalog_column(
+    session: Session, database: Database, catalog: str, downgrade: bool = False
+) -> None:
+    """
+    Update the `catalog` column in the specified models to the given catalog.
+
+    This function iterates over a list of models defined by MODELS and updates
+    the `catalog` column to the specified catalog or to None, depending on the
+    `downgrade` parameter. The update is performed in batches to optimize
+    performance and reduce memory usage.
+
+    Parameters:
+        session (Session): The SQLAlchemy session to use for database operations.
+        database (Database): The database instance containing the models to update.
+        catalog (str): The new catalog value to set in the `catalog` column, or
+            the default catalog if `downgrade` is True.
+        downgrade (bool): If True, the `catalog` column is set to None where the
+            catalog matches the specified catalog.
+    """
+    start_time = datetime.now()
+
+    logger.info(f"Updating {database.database_name} models to catalog {catalog}")
+
+    for model, column in MODELS:
+        # Get the total number of rows that match the condition
+        total_rows = (
+            session.query(sa.func.count(model.id))
+            .filter(getattr(model, column) == database.id)
+            .filter(model.catalog == catalog if downgrade else True)
+            .scalar()
+        )
+
+        logger.info(
+            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
+        )
+
+        # Update in batches using row numbers
+        for i in range(0, total_rows, BATCH_SIZE):
+            subquery = (
+                session.query(model.id)
+                .filter(getattr(model, column) == database.id)
+                .filter(model.catalog == catalog if downgrade else True)
+                .order_by(model.id)
+                .offset(0 if downgrade else i)  # updated rows leave the filter on downgrade
+                .limit(BATCH_SIZE)
+                .subquery()
+            )
+            session.execute(
+                sa.update(model)
+                .where(model.id == subquery.c.id)
+                .values(catalog=None if downgrade else catalog)
+                .execution_options(synchronize_session=False)
+            )
+            # Commit the transaction after each batch
+            session.commit()
+            print_processed_batch(start_time, i, total_rows, model)
+
+
+def delete_models_non_default_catalog(
+    session: Session, database: Database, catalog: str
+) -> None:
+    """
+    Delete models that are not in the default catalog.
+
+    This function iterates over a list of models defined by MODELS and deletes
+    the rows where the `catalog` column does not match the specified catalog.
+
+    Parameters:
+        session (Session): The SQLAlchemy session to use for database operations.
+        database (Database): The database instance containing the models to delete.
+        catalog (str): The catalog used to filter the models to delete.
+    """
+    start_time = datetime.now()
+
+    logger.info(f"Deleting models not in the default catalog: {catalog}")
+
+    for model, column in MODELS:
+        # Get the total number of rows that match the condition
+        total_rows = (
+            session.query(sa.func.count(model.id))
+            .filter(getattr(model, column) == database.id)
+            .filter(model.catalog != catalog)
+            .scalar()
+        )
+
+        logger.info(
+            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
+        )
+
+        # Delete in batches using row numbers
+        for i in range(0, total_rows, BATCH_SIZE):
+            subquery = (
+                session.query(model.id)
+                .filter(getattr(model, column) == database.id)
+                .filter(model.catalog != catalog)
+                .order_by(model.id)
+                .offset(0)  # deleted rows leave the result set, so always take the first batch
+                .limit(BATCH_SIZE)
+                .subquery()
+            )
+            session.execute(
+                sa.delete(model)
+                .where(model.id == subquery.c.id)
+                .execution_options(synchronize_session=False)
+            )
+            # Commit the transaction after each batch
+            session.commit()
+            print_processed_batch(start_time, i, total_rows, model)
+
+
 def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
     """
     Update models and permissions when catalogs are introduced in a DB engine spec.
@@ -170,17 +318,7 @@ def upgrade_database_catalogs( # update existing models that have a `catalog` column so it points to the default # catalog - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id - ): - instance.catalog = default_catalog + update_catalog_column(session, database, default_catalog, False) # update `schema_perm` and `catalog_perm` for tables and charts for table in session.query(SqlaTable).filter_by( @@ -374,19 +512,7 @@ def downgrade_database_catalogs( # permissions associated with other catalogs downgrade_schema_perms(database, default_catalog, session) - # update existing models - models = [ - (Query, "database_id"), - (SavedQuery, "db_id"), - (TabState, "database_id"), - (TableSchema, "database_id"), - ] - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id, - model.catalog == default_catalog, # type: ignore - ): - instance.catalog = None + update_catalog_column(session, database, default_catalog, True) # update `schema_perm` for tables and charts for table in session.query(SqlaTable).filter_by( @@ -411,12 +537,7 @@ def downgrade_database_catalogs( chart.schema_perm = schema_perm # delete models referencing non-default catalogs - for model, column in models: - for instance in session.query(model).filter( - getattr(model, column) == database.id, - model.catalog != default_catalog, # type: ignore - ): - session.delete(instance) + delete_models_non_default_catalog(session, database, default_catalog) # delete datasets and any associated permissions for table in session.query(SqlaTable).filter(
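
Reviewer note: the sketch below shows how these helpers are typically invoked
from an Alembic migration. It is illustrative only: the revision identifiers
and the "postgresql" engine filter are placeholders, and it assumes the module
exposes a downgrade_catalog_perms counterpart symmetric to upgrade_catalog_perms
(as the downgrade_database_catalogs function in this patch suggests).

    # Hypothetical migration, e.g. superset/migrations/versions/xxxx_enable_catalogs.py
    from superset.migrations.shared.catalogs import (
        downgrade_catalog_perms,  # assumed counterpart to upgrade_catalog_perms
        upgrade_catalog_perms,
    )

    # revision identifiers, used by Alembic (placeholders)
    revision = "ffffffffffff"
    down_revision = "eeeeeeeeeeee"


    def upgrade():
        # populate `catalog` columns and permissions only for databases whose
        # engine spec now supports catalogs; other databases are untouched
        upgrade_catalog_perms(engines={"postgresql"})


    def downgrade():
        # revert models to the default catalog and remove catalog permissions
        downgrade_catalog_perms(engines={"postgresql"})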