Skip to content

Commit

Permalink
fix: create permissions on DB import (#29802)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Aug 6, 2024
1 parent 1c3ef01 commit 61c0970
Show file tree
Hide file tree
Showing 18 changed files with 273 additions and 87 deletions.
3 changes: 2 additions & 1 deletion superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
Expand Down Expand Up @@ -118,7 +119,7 @@ def run(self) -> Model:
for catalog in catalogs:
try:
self.add_schema_permissions(database, catalog, ssh_tunnel)
except Exception: # pylint: disable=broad-except
except GenericDBException: # pylint: disable=broad-except
logger.warning("Error processing catalog '%s'", catalog)
continue
except (
Expand Down
49 changes: 45 additions & 4 deletions superset/commands/database/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,55 @@ def import_database(
config["extra"] = json.dumps(config["extra"])

# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
ssh_tunnel_config = config.pop("ssh_tunnel", None)

database = Database.import_from_dict(config, recursive=False)
if database.id is None:
db.session.flush()

if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)
if ssh_tunnel_config:
ssh_tunnel_config["database_id"] = database.id
ssh_tunnel = SSHTunnel.import_from_dict(ssh_tunnel_config, recursive=False)
else:
ssh_tunnel = None

# TODO (betodealmeida): we should use the `CreateDatabaseCommand` for imports
add_permissions(database, ssh_tunnel)

return database


def add_permissions(database: Database, ssh_tunnel: SSHTunnel) -> None:
"""
Add DAR for catalogs and schemas.
"""
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names(
cache=False,
ssh_tunnel=ssh_tunnel,
)
for catalog in catalogs:
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name,
catalog,
),
)
else:
catalogs = [None]

for catalog in catalogs:
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
28 changes: 12 additions & 16 deletions superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction

Expand Down Expand Up @@ -80,6 +81,7 @@ def run(self) -> Model:
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)

return database

def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
Expand Down Expand Up @@ -115,17 +117,13 @@ def _get_catalog_names(
) -> set[str]:
"""
Helper method to load catalogs.
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""

try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

def _get_schema_names(
Expand All @@ -136,18 +134,14 @@ def _get_schema_names(
) -> set[str]:
"""
Helper method to load schemas.
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""

try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

def _refresh_catalogs(
Expand Down Expand Up @@ -255,7 +249,7 @@ def _rename_database_in_permissions(
catalog: str | None,
schemas: set[str],
) -> None:
new_name = security_manager.get_catalog_perm(
new_catalog_perm_name = security_manager.get_catalog_perm(
database.database_name,
catalog,
)
Expand All @@ -271,10 +265,10 @@ def _rename_database_in_permissions(
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
existing_pvm.view_menu.name = new_catalog_perm_name

for schema in schemas:
new_name = security_manager.get_schema_perm(
new_schema_perm_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
Expand All @@ -291,17 +285,19 @@ def _rename_database_in_permissions(
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
existing_pvm.view_menu.name = new_schema_perm_name

# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
catalog=catalog,
schema=schema,
):
dataset.schema_perm = new_name
dataset.catalog_perm = new_catalog_perm_name
dataset.schema_perm = new_schema_perm_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
chart.schema_perm = new_name
chart.catalog_perm = new_catalog_perm_name
chart.schema_perm = new_schema_perm_name

def validate(self) -> None:
if database_name := self._properties.get("database_name"):
Expand Down
18 changes: 18 additions & 0 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,25 @@ def get_default_catalog(
cls,
database: Database,
) -> str | None:
"""
Return the default catalog.
The default behavior for Databricks is confusing. When Unity Catalog is not
enabled we have (the DB engine spec hasn't been tested with it enabled):
> SHOW CATALOGS;
spark_catalog
> SELECT current_catalog();
hive_metastore
To handle permissions correctly we use the result of `SHOW CATALOGS` when a
single catalog is returned.
"""
with database.get_sqla_engine() as engine:
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
if len(catalogs) == 1:
return catalogs.pop()

return engine.execute("SELECT current_catalog()").scalar()

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from io import BytesIO
from unittest import mock
from unittest.mock import patch
from zipfile import is_zipfile, ZipFile

import prison
Expand Down Expand Up @@ -1768,7 +1769,8 @@ def test_export_chart_gamma(self):

assert rv.status_code == 404

def test_import_chart(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart(self, mock_add_permissions):
"""
Chart API: Test import chart
"""
Expand Down Expand Up @@ -1805,7 +1807,8 @@ def test_import_chart(self):
db.session.delete(database)
db.session.commit()

def test_import_chart_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart_overwrite(self, mock_add_permissions):
"""
Chart API: Test import existing chart
"""
Expand Down Expand Up @@ -1876,7 +1879,8 @@ def test_import_chart_overwrite(self):
db.session.delete(database)
db.session.commit()

def test_import_chart_invalid(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_chart_invalid(self, mock_add_permissions):
"""
Chart API: Test import invalid chart
"""
Expand Down
9 changes: 6 additions & 3 deletions tests/integration_tests/charts/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def test_export_chart_command_no_related(self, mock_g):
class TestImportChartsCommand(SupersetTestCase):
@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
def test_import_v1_chart(self, sm_g, utils_g) -> None:
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart(self, mock_add_permissions, sm_g, utils_g) -> None:
"""Test that we can import a chart"""
admin = sm_g.user = utils_g.user = security_manager.find_user("admin")
contents = {
Expand Down Expand Up @@ -246,7 +247,8 @@ def test_import_v1_chart(self, sm_g, utils_g) -> None:
db.session.commit()

@patch("superset.security.manager.g")
def test_import_v1_chart_multiple(self, sm_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart_multiple(self, mock_add_permissions, sm_g):
"""Test that a chart can be imported multiple times"""
sm_g.user = security_manager.find_user("admin")
contents = {
Expand All @@ -272,7 +274,8 @@ def test_import_v1_chart_multiple(self, sm_g):
db.session.delete(database)
db.session.commit()

def test_import_v1_chart_validation(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_chart_validation(self, mock_add_permissions):
"""Test different validations applied when importing a chart"""
# metadata.yaml must be present
contents = {
Expand Down
13 changes: 11 additions & 2 deletions tests/integration_tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import copy
from unittest.mock import patch

import yaml
from flask import g
Expand Down Expand Up @@ -63,8 +64,10 @@ def setUp(self):
self.user = user
setattr(g, "user", user)

def test_import_assets(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_assets(self, mock_add_permissions):
"""Test that we can import multiple assets"""

contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
Expand Down Expand Up @@ -144,13 +147,16 @@ def test_import_assets(self):

assert dashboard.owners == [self.user]

mock_add_permissions.assert_called_with(database, None)

db.session.delete(dashboard)
db.session.delete(chart)
db.session.delete(dataset)
db.session.delete(database)
db.session.commit()

def test_import_v1_dashboard_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard_overwrite(self, mock_add_permissions):
"""Test that assets can be overwritten"""
contents = {
"metadata.yaml": yaml.safe_dump(metadata_config),
Expand Down Expand Up @@ -185,6 +191,9 @@ def test_import_v1_dashboard_overwrite(self):
chart = dashboard.slices[0]
dataset = chart.table
database = dataset.database

mock_add_permissions.assert_called_with(database, None)

db.session.delete(dashboard)
db.session.delete(chart)
db.session.delete(dataset)
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/dashboards/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,7 +2111,8 @@ def test_export_bundle_not_allowed(self):
db.session.delete(dashboard)
db.session.commit()

def test_import_dashboard(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_dashboard(self, mock_add_permissions):
"""
Dashboard API: Test import dashboard
"""
Expand Down Expand Up @@ -2215,7 +2216,8 @@ def test_import_dashboard_v0_export(self):
db.session.delete(dataset)
db.session.commit()

def test_import_dashboard_overwrite(self):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_dashboard_overwrite(self, mock_add_permissions):
"""
Dashboard API: Test import existing dashboard
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/dashboards/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ def test_import_v0_dashboard_cli_export(self):

@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
def test_import_v1_dashboard(self, sm_g, utils_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard(self, mock_add_permissions, sm_g, utils_g):
"""Test that we can import a dashboard"""
admin = sm_g.user = utils_g.user = security_manager.find_user("admin")
contents = {
Expand Down Expand Up @@ -583,7 +584,8 @@ def test_import_v1_dashboard(self, sm_g, utils_g):
db.session.commit()

@patch("superset.security.manager.g")
def test_import_v1_dashboard_multiple(self, mock_g):
@patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_v1_dashboard_multiple(self, mock_add_permissions, mock_g):
"""Test that a dashboard can be imported multiple times"""
mock_g.user = security_manager.find_user("admin")

Expand Down
Loading

0 comments on commit 61c0970

Please sign in to comment.