Skip to content

Commit

Permalink
Merge branch 'main' into release-docker-img
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Sep 25, 2024
2 parents 93d9840 + c786def commit c423547
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

"""The LLM Foundry Version."""

__version__ = '0.12.0.dev0'
__version__ = '0.13.0.dev0'
24 changes: 23 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,27 @@ def run_query(
elif method == 'dbconnect':
if spark == None:
raise ValueError(f'sparkSession is required for dbconnect')
df = spark.sql(query)

try:
df = spark.sql(query)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
match = re.search(
r"Schema\s+'([^']+)'",
e.message, # pyright: ignore
)
if match:
schema_name = match.group(1)
action = f'using the schema {schema_name}'
else:
action = 'using the schema'
raise InsufficientPermissionsError(action=action,) from e
raise RuntimeError(
f'Error in querying into schema. Restart sparkSession and try again',
) from e

if collect:
return df.collect()
return df
Expand Down Expand Up @@ -461,6 +481,8 @@ def fetch(
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
if isinstance(e, InsufficientPermissionsError):
raise e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/command_utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def evaluate_model(
warnings.warn(
VersionedDeprecationWarning(
'The argument fsdp_config is deprecated. Please use parallelism_config instead.',
remove_version='0.13.0',
remove_version='0.14.0',
),
)
if fsdp_config and parallelism_config:
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
warnings.warn(
VersionedDeprecationWarning(
'`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.',
remove_version='0.13.0',
remove_version='0.14.0',
),
)
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
'mlflow>=2.14.1,<2.17',
'accelerate>=0.25,<0.34', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.8.1,<0.9',
'mosaicml-streaming>=0.9.0,<0.10',
'torch>=2.4.0,<2.4.1',
'datasets>=2.19,<2.20',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
Expand Down
35 changes: 35 additions & 0 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
from argparse import Namespace
from typing import Any
from unittest.mock import MagicMock, mock_open, patch

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
InsufficientPermissionsError,
download,
fetch_DT,
format_tablename,
Expand All @@ -17,6 +19,39 @@

class TestConvertDeltaToJsonl(unittest.TestCase):

def test_run_query_dbconnect_insufficient_permissions(self):
    """Check the dbconnect path reports a schema permission failure.

    A mocked Spark session raises an AnalysisException-like error whose
    message carries the INSUFFICIENT_PERMISSIONS marker and a schema name.
    The test expects ``run_query`` to raise ``InsufficientPermissionsError``
    mentioning that schema, after issuing the query exactly once.
    """

    class MockAnalysisException(Exception):
        """Minimal stand-in exposing only the ``message`` attribute."""

        def __init__(self, message: str):
            self.message = message

    permission_error_text = (
        '[INSUFFICIENT_PERMISSIONS] Insufficient privileges: User does not have USE SCHEMA '
        "on Schema 'main.oogabooga'. SQLSTATE: 42501"
    )
    query = 'SELECT * FROM table'

    # Fake `pyspark.errors` module so that importing AnalysisException from
    # it yields the stand-in class above instead of requiring real pyspark.
    fake_pyspark_errors = MagicMock()
    fake_pyspark_errors.AnalysisException = MockAnalysisException  # pyright: ignore

    failing_spark = MagicMock()
    failing_spark.sql.side_effect = MockAnalysisException(
        permission_error_text,
    )

    with patch.dict('sys.modules', {'pyspark.errors': fake_pyspark_errors}):
        with self.assertRaises(InsufficientPermissionsError) as raised:
            run_query(
                query,
                method='dbconnect',
                cursor=None,
                spark=failing_spark,
            )

    self.assertIn(
        'using the schema main.oogabooga',
        str(raised.exception),
    )
    failing_spark.sql.assert_called_once_with(query)

@patch(
'databricks.sql.connect',
)
Expand Down

0 comments on commit c423547

Please sign in to comment.