Skip to content

Commit

Permalink
Get correct backend dtypes when prune_function==False and intersect w…
Browse files Browse the repository at this point in the history
…ith frontend dtypes for frontend tests.
  • Loading branch information
ReneFabricius committed Sep 14, 2023
1 parent 4de4166 commit 44ac5bb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 0 additions & 2 deletions ivy_tests/test_ivy/helpers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,13 @@ def setup_frontend_test(frontend: str, backend: str, device: str, test_data: Tes
_set_frontend(frontend)
_set_backend(backend)
_set_device(device)
_set_ground_truth_backend(frontend)


def teardown_frontend_test():
_unset_test_data()
_unset_frontend()
_unset_backend()
_unset_device()
_unset_ground_truth_backend()


def _set_test_data(test_data: TestData):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _get_type_dict_helper(framework, kind, is_frontend_test):
if is_frontend_test:
framework_module = get_frontend_config(framework).supported_dtypes
else:
framework_module = ivy
framework_module = ivy.with_backend(framework)

if kind == "valid":
return framework_module.valid_dtypes
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_dtypes(
function as the keyword argument with the given name.
prune_function
if True, the function will prune the data types to only include the ones that
are supported by the current backend. If False, the function will return all
are supported by the current function. If False, the function will return all
the data types supported by the current backend.
Returns
Expand Down Expand Up @@ -225,6 +225,9 @@ def get_dtypes(
# FN_DTYPES & BACKEND_DTYPES & FRONTEND_DTYPES & GROUND_TRUTH_DTYPES

# If being called from a frontend test
if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval:
frontend_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND, kind, True)
valid_dtypes = valid_dtypes.intersection(frontend_dtypes)

# Make sure we return dtypes that are compatible with ground truth backend
ground_truth_is_set = (
Expand Down

0 comments on commit 44ac5bb

Please sign in to comment.