Skip to content

Commit

Permalink
fix: upgrade_catalog_perms implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina committed Aug 6, 2024
1 parent c7dc4dc commit 2bd9015
Showing 1 changed file with 153 additions and 32 deletions.
185 changes: 153 additions & 32 deletions superset/migrations/shared/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any, Type

import sqlalchemy as sa
Expand All @@ -35,8 +36,7 @@
)
from superset.models.core import Database

logger = logging.getLogger(__name__)

logger = logging.getLogger("alembic")

Base: Type[Any] = declarative_base()

Expand Down Expand Up @@ -95,6 +95,18 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))


ModelType = Type[Query] | Type[SavedQuery] | Type[TabState] | Type[TableSchema]

BATCH_SIZE = 10000

MODELS: list[tuple[ModelType, str]] = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]


def get_known_schemas(database_name: str, session: Session) -> list[str]:
"""
Read all known schemas from the existing schema permissions.
Expand All @@ -112,6 +124,142 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
return sorted({name[0][1:-1].split("].[")[-1] for name in names})


def print_processed_batch(
    start_time: datetime, offset: int, total_rows: int, model: ModelType
) -> None:
    """
    Log the progress of batch processing.

    Computes the elapsed time since ``start_time`` and logs how many rows of
    ``model`` have been processed so far, together with the completion
    percentage.

    Parameters:
        start_time (datetime): The start time of the batch processing.
        offset (int): The offset of the batch that was just processed.
        total_rows (int): The total number of rows to process.
        model (ModelType): The model being processed.
    """
    elapsed_seconds = (datetime.now() - start_time).total_seconds()
    # Format elapsed time as HH:MM:SS for readability in migration logs.
    elapsed_formatted = (
        f"{int(elapsed_seconds // 3600):02}:"
        f"{int((elapsed_seconds % 3600) // 60):02}:"
        f"{int(elapsed_seconds % 60):02}"
    )
    rows_processed = min(offset + BATCH_SIZE, total_rows)
    # Guard against ZeroDivisionError when there are no rows to process.
    percentage = (rows_processed / total_rows) * 100 if total_rows else 100.0
    logger.info(
        f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} "
        f"{model.__tablename__} rows processed ({percentage:.2f}%)"
    )


def update_catalog_column(
    session: Session, database: Database, catalog: str, downgrade: bool = False
) -> None:
    """
    Update the `catalog` column in the specified models to the given catalog.

    This function iterates over a list of models defined by MODELS and updates
    the `catalog` column to the specified catalog or None depending on the
    downgrade parameter. The update is performed in batches to optimize
    performance and reduce memory usage.

    Parameters:
        session (Session): The SQLAlchemy session to use for database operations.
        database (Database): The database instance containing the models to update.
        catalog (str): The new catalog value to set in the `catalog` column, or
            the default catalog if `downgrade` is True.
        downgrade (bool): If True, the `catalog` column is set to None where the
            catalog matches the specified catalog.
    """
    start_time = datetime.now()

    logger.info(f"Updating {database.database_name} models to catalog {catalog}")

    for model, column in MODELS:
        # Count matching rows up front so progress can be reported and the
        # work split into fixed-size batches.
        total_rows = (
            session.query(sa.func.count(model.id))
            .filter(getattr(model, column) == database.id)
            .filter(model.catalog == catalog if downgrade else True)
            .scalar()
        )

        logger.info(
            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
        )

        # Update in batches using row numbers
        for i in range(0, total_rows, BATCH_SIZE):
            # When downgrading, rows updated to NULL no longer match the
            # `catalog == catalog` filter, so the remaining rows shift to the
            # front of the result set after each commit; advancing the offset
            # would skip every other batch. Always read from offset 0 in that
            # case. When upgrading, the filter (TRUE) is unaffected by the
            # update, so the normal advancing offset is correct.
            offset = 0 if downgrade else i
            subquery = (
                session.query(model.id)
                .filter(getattr(model, column) == database.id)
                .filter(model.catalog == catalog if downgrade else True)
                .order_by(model.id)
                .offset(offset)
                .limit(BATCH_SIZE)
                .subquery()
            )
            session.execute(
                sa.update(model)
                .where(model.id == subquery.c.id)
                .values(catalog=None if downgrade else catalog)
                .execution_options(synchronize_session=False)
            )
            # Commit after each batch to keep transactions small.
            session.commit()
            print_processed_batch(start_time, i, total_rows, model)


def delete_models_non_default_catalog(
    session: Session, database: Database, catalog: str
) -> None:
    """
    Delete models that are not in the default catalog.

    This function iterates over a list of models defined by MODELS and deletes
    the rows where the `catalog` column does not match the specified catalog.

    Parameters:
        session (Session): The SQLAlchemy session to use for database operations.
        database (Database): The database instance containing the models to delete.
        catalog (str): The catalog to use to filter the models to delete.
    """
    start_time = datetime.now()

    logger.info(f"Deleting models not in the default catalog: {catalog}")

    for model, column in MODELS:
        # Count matching rows up front so progress can be reported and the
        # work split into fixed-size batches.
        total_rows = (
            session.query(sa.func.count(model.id))
            .filter(getattr(model, column) == database.id)
            .filter(model.catalog != catalog)
            .scalar()
        )

        logger.info(
            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
        )

        # Delete in batches. Deleted rows drop out of the filtered result set
        # after each commit, so the remaining rows shift to the front; an
        # advancing offset would skip every other batch. Always take the next
        # batch from the front instead.
        for i in range(0, total_rows, BATCH_SIZE):
            subquery = (
                session.query(model.id)
                .filter(getattr(model, column) == database.id)
                .filter(model.catalog != catalog)
                .order_by(model.id)
                .limit(BATCH_SIZE)
                .subquery()
            )
            session.execute(
                sa.delete(model)
                .where(model.id == subquery.c.id)
                .execution_options(synchronize_session=False)
            )
            # Commit after each batch to keep transactions small.
            session.commit()
            print_processed_batch(start_time, i, total_rows, model)


def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models and permissions when catalogs are introduced in a DB engine spec.
Expand Down Expand Up @@ -170,17 +318,7 @@ def upgrade_database_catalogs(

# update existing models that have a `catalog` column so it points to the default
# catalog
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id
):
instance.catalog = default_catalog
update_catalog_column(session, database, default_catalog, False)

# update `schema_perm` and `catalog_perm` for tables and charts
for table in session.query(SqlaTable).filter_by(
Expand Down Expand Up @@ -374,19 +512,7 @@ def downgrade_database_catalogs(
# permissions associated with other catalogs
downgrade_schema_perms(database, default_catalog, session)

# update existing models
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id,
model.catalog == default_catalog, # type: ignore
):
instance.catalog = None
update_catalog_column(session, database, default_catalog, True)

# update `schema_perm` for tables and charts
for table in session.query(SqlaTable).filter_by(
Expand All @@ -411,12 +537,7 @@ def downgrade_database_catalogs(
chart.schema_perm = schema_perm

# delete models referencing non-default catalogs
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id,
model.catalog != default_catalog, # type: ignore
):
session.delete(instance)
delete_models_non_default_catalog(session, database, default_catalog)

# delete datasets and any associated permissions
for table in session.query(SqlaTable).filter(
Expand Down

0 comments on commit 2bd9015

Please sign in to comment.