diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 666d0278c6..d676fc2165 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -233,7 +233,27 @@ def run_query( elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(query) + + try: + df = spark.sql(query) + except Exception as e: + from pyspark.errors import AnalysisException + if isinstance(e, AnalysisException): + if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore + match = re.search( + r"Schema\s+'([^']+)'", + e.message, # pyright: ignore + ) + if match: + schema_name = match.group(1) + action = f'using the schema {schema_name}' + else: + action = 'using the schema' + raise InsufficientPermissionsError(action=action,) from e + raise RuntimeError( + f'Error in querying into schema. Restart sparkSession and try again', + ) from e + if collect: return df.collect() return df @@ -461,6 +481,8 @@ def fetch( raise InsufficientPermissionsError( action=f'reading from {tablename}', ) from e + if isinstance(e, InsufficientPermissionsError): + raise e raise RuntimeError( f'Error in get rows from {tablename}. Restart sparkSession and try again', ) from e diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index e623467bf7..bbb03a26d9 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -1,12 +1,14 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import sys import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( + InsufficientPermissionsError, download, fetch_DT, format_tablename, @@ -17,6 +19,39 @@ class TestConvertDeltaToJsonl(unittest.TestCase): + def test_run_query_dbconnect_insufficient_permissions(self): + error_message = ( + '[INSUFFICIENT_PERMISSIONS] Insufficient privileges: User does not have USE SCHEMA ' + "on Schema 'main.oogabooga'. SQLSTATE: 42501" + ) + + class MockAnalysisException(Exception): + + def __init__(self, message: str): + self.message = message + + with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}): + sys.modules[ + 'pyspark.errors' + ].AnalysisException = MockAnalysisException # pyright: ignore + + mock_spark = MagicMock() + mock_spark.sql.side_effect = MockAnalysisException(error_message) + + with self.assertRaises(InsufficientPermissionsError) as context: + run_query( + 'SELECT * FROM table', + method='dbconnect', + cursor=None, + spark=mock_spark, + ) + + self.assertIn( + 'using the schema main.oogabooga', + str(context.exception), + ) + mock_spark.sql.assert_called_once_with('SELECT * FROM table') + @patch( 'databricks.sql.connect', )