Skip to content

Commit

Permalink
Raise exception if undo is not implemented.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Feb 22, 2024
1 parent 589eb1b commit b6c3703
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 17 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ addopts =
--doctest-modules
filterwarnings =
ignore::DeprecationWarning

ignore::UserWarning

[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
Expand Down
15 changes: 8 additions & 7 deletions src/continuity/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
Transformed tensor.
"""

@abstractmethod
def undo(self, tensor: torch.Tensor) -> torch.Tensor:
"""Undoes the inverse (given the transformation is bijective).
When the transformation is not bijective (one-to-one correspondence of data), the inverse/backward
transformation is not applied. Instead, a warning should be given to the user and an appropriate approximate
inverse transformation should be provided.
"""Applies the inverse of the transformation (if it exists).
Args:
tensor: Transformed tensor.
Returns:
Tensor with the transformation undone (given it is possible).
Tensor with the transformation undone.
Raises:
NotImplementedError: If the inverse of the transformation is not implemented.
"""
raise NotImplementedError(
"The undo method is not implemented for this transform."
)
7 changes: 1 addition & 6 deletions tests/transforms/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
import torch
import warnings

from continuity.transforms import Transform

Expand Down Expand Up @@ -45,11 +44,7 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
return torch.abs(tensor)

def undo(self, tensor: torch.Tensor) -> torch.Tensor:
warnings.warn(
f"The {self.__class__.__name__} transformation is not bijective. "
f"Returns the identity instead!",
stacklevel=2,
)
"""The `abs_transform` transformation is not bijective, therefore returns identity."""
return tensor

return Abs()
4 changes: 1 addition & 3 deletions tests/transforms/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,4 @@ def test_transform_undo(plus_one_transform, random_tensor):


def test_transform_undo_not_bijective(abs_transform, random_tensor):
with pytest.warns(UserWarning) as record:
assert torch.allclose(abs_transform.undo(random_tensor), random_tensor)
assert len(record) == 1
assert torch.allclose(abs_transform.undo(random_tensor), random_tensor)

0 comments on commit b6c3703

Please sign in to comment.