Skip to content

Commit

Permalink
feat(test): added assertion of same type in the backend test function
Browse files Browse the repository at this point in the history
  • Loading branch information
sherry30 committed Sep 15, 2023
1 parent 190db93 commit a5da9e1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
17 changes: 17 additions & 0 deletions ivy_tests/test_ivy/helpers/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ def assert_same_type_and_shape(values, this_key_chain=None):
), "returned dtype = {}, ground-truth returned dtype = {}".format(x_d, y_d)


def assert_same_type(ret_from_target, ret_from_gt, backend_to_test, gt_backend):
"""
Assert that the return types from the target and ground truth frameworks are the
same.
checks with a string comparison because with_backend returns
different objects. Doesn't check recursively.
"""
# ToDo: do this with nested map
assert_msg = (
f"ground truth backend ({gt_backend}) returned"
f" {type(ret_from_gt)} but target backend ({backend_to_test}) returned"
f" {type(ret_from_target)}"
)
assert str(type(ret_from_target)) == str(type(ret_from_gt)), assert_msg


def value_test(
*,
ret_np_flat,
Expand Down
26 changes: 15 additions & 11 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ivy_tests.test_ivy.helpers.testing_helpers import _create_transpile_report
from .assertions import (
value_test,
assert_same_type,
check_unsupported_dtype,
)

Expand Down Expand Up @@ -567,17 +568,6 @@ def test_function(
on_device=on_device,
)

assert ret_device == ret_from_gt_device, (
f"ground truth backend ({test_flags.ground_truth_backend}) returned array on"
f" device {ret_from_gt_device} but target backend ({backend_to_test})"
f" returned array on device {ret_device}"
)
if ret_device is not None:
assert ret_device == on_device, (
f"device is set to {on_device}, but ground truth produced array on"
f" {ret_device}"
)

# assuming value test will be handled manually in the test function
if not test_values:
if return_flat_np_arrays:
Expand All @@ -599,6 +589,20 @@ def test_function(
backend=backend_to_test,
ground_truth_backend=test_flags.ground_truth_backend,
)
assert_same_type(
ret_from_target, ret_from_gt, backend_to_test, test_flags.ground_truth_backend
)

assert ret_device == ret_from_gt_device, (
f"ground truth backend ({test_flags.ground_truth_backend}) returned array on"
f" device {ret_from_gt_device} but target backend ({backend_to_test})"
f" returned array on device {ret_device}"
)
if ret_device is not None:
assert ret_device == on_device, (
f"device is set to {on_device}, but ground truth produced array on"
f" {ret_device}"
)


def test_frontend_function(
Expand Down

0 comments on commit a5da9e1

Please sign in to comment.