diff --git a/tests/transforms/test_transform.py b/tests/transforms/test_transform.py index 4bc5af56..536a18f3 100644 --- a/tests/transforms/test_transform.py +++ b/tests/transforms/test_transform.py @@ -23,4 +23,6 @@ def test_transform_undo(plus_one_transform, random_tensor): def test_transform_undo_not_bijective(abs_transform, random_tensor): - assert torch.allclose(abs_transform.undo(random_tensor), random_tensor) + with pytest.warns(UserWarning) as record: + assert torch.allclose(abs_transform.undo(random_tensor), random_tensor) + assert len(record) == 1