Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Sep 27, 2024
1 parent 1494c91 commit dcf4569
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
8 changes: 8 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,12 @@ class InsufficientPermissionsError(UserError):
"""Error thrown when the user does not have sufficient permissions."""

def __init__(self, message: str) -> None:
    """Record the permissions error message and initialize the base error.

    Args:
        message: Human-readable description of the missing permission.

    The message is stored on the instance so that ``__reduce__`` and
    ``__str__`` below can access it; ``super().__init__`` is called with
    the same message afterwards (base-class behavior not visible here —
    NOTE(review): assumes the base does not overwrite ``self.message``).
    """
    self.message = message
    super().__init__(message)

def __reduce__(self):
# Return a tuple of class, a tuple of arguments, and optionally state
return (InsufficientPermissionsError, (self.message,))

def __str__(self):
return self.message
23 changes: 15 additions & 8 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
InsufficientPermissionsError,
download,
fetch,
fetch_DT,
format_tablename,
iterative_combine_jsons,
Expand All @@ -30,27 +31,33 @@ class MockAnalysisException(Exception):
def __init__(self, message: str):
    """Stand-in for pyspark's ``AnalysisException``: keep the message for assertions."""
    self.message = message

def __str__(self):
    """Return the stored message so ``str(exc)`` matches the original error text."""
    return self.message

with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}):
sys.modules[
'pyspark.errors'
].AnalysisException = MockAnalysisException # pyright: ignore
].AnalysisException = MockAnalysisException # type: ignore

mock_spark = MagicMock()
mock_spark.sql.side_effect = MockAnalysisException(error_message)

with self.assertRaises(InsufficientPermissionsError) as context:
run_query(
'SELECT * FROM table',
fetch(
method='dbconnect',
cursor=None,
spark=mock_spark,
tablename='main.oogabooga',
json_output_folder='/fake/path',
batch_size=1,
processes=1,
sparkSession=mock_spark,
dbsql=None,
)

self.assertIn(
'using the schema main.oogabooga',
self.assertEqual(
str(context.exception),
error_message,
)
mock_spark.sql.assert_called_once_with('SELECT * FROM table')
mock_spark.sql.assert_called()

@patch(
'databricks.sql.connect',
Expand Down
36 changes: 25 additions & 11 deletions tests/utils/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@
def create_exception_object(
exception_class: type[foundry_exceptions.BaseContextualError],
):
# get required arg types of exception class by inspecting its __init__ method

if hasattr(inspect, 'get_annotations'):
required_args = inspect.get_annotations( # type: ignore
exception_class.__init__,
) # type: ignore
else:
required_args = exception_class.__init__.__annotations__ # python 3.9 and below

# create a dictionary of required args with default values
def get_init_annotations(cls: type):
    """Fetch the raw annotation mapping of ``cls.__init__``.

    Prefers :func:`inspect.get_annotations` when the running interpreter
    provides it (Python 3.10+); otherwise reads the ``__annotations__``
    attribute directly, defaulting to an empty dict when absent.
    """
    init = cls.__init__
    if not hasattr(inspect, 'get_annotations'):
        return getattr(init, '__annotations__', {})
    return inspect.get_annotations(init)

# First, try to get annotations from the class itself
required_args = get_init_annotations(exception_class)

# If the annotations are empty, look at parent classes
if not required_args:
for parent in exception_class.__bases__:
if parent == object:
break
parent_args = get_init_annotations(parent)
if parent_args:
required_args = parent_args
break

# Remove self, return, and kwargs
required_args.pop('self', None)
required_args.pop('return', None)
required_args.pop('kwargs', None)

def get_default_value(arg_type: Optional[type] = None):
Expand Down Expand Up @@ -51,8 +64,6 @@ def get_default_value(arg_type: Optional[type] = None):
return [{'key': 'value'}]
raise ValueError(f'Unsupported arg type: {arg_type}')

required_args.pop('self', None)
required_args.pop('return', None)
kwargs = {
arg: get_default_value(arg_type)
for arg, arg_type in required_args.items()
Expand Down Expand Up @@ -80,6 +91,7 @@ def filter_exceptions(possible_exceptions: list[str]):
def test_exception_serialization(
exception_class: type[foundry_exceptions.BaseContextualError],
):
print(f'Testing serialization for {exception_class.__name__}')
excluded_base_classes = [
foundry_exceptions.InternalError,
foundry_exceptions.UserError,
Expand All @@ -88,13 +100,15 @@ def test_exception_serialization(
]

exception = create_exception_object(exception_class)
print(f'Created exception object: {exception}')

expect_reduce_error = exception.__class__ in excluded_base_classes
error_context = pytest.raises(
NotImplementedError,
) if expect_reduce_error else contextlib.nullcontext()

exc_str = str(exception)
print(f'Exception string: {exc_str}')
with error_context:
pkl = pickle.dumps(exception)
unpickled_exc = pickle.loads(pkl)
Expand Down

0 comments on commit dcf4569

Please sign in to comment.