Skip to content

Commit

Permalink
Ensure that transformed masks are contiguous arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Björn Barz committed Oct 9, 2024
1 parent 1ad56dc commit d31a905
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
13 changes: 4 additions & 9 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ColorType,
Targets,
)
from .utils import format_args
from .utils import ensure_contiguous_output, format_args

__all__ = ["BasicTransform", "DualTransform", "ImageOnlyTransform", "NoOp"]

Expand Down Expand Up @@ -141,14 +141,9 @@ def apply_with_params(self, params: dict[str, Any], *args: Any, **kwargs: Any) -
for key, arg in kwargs.items():
if key in self._key2func and arg is not None:
target_function = self._key2func[key]
if isinstance(arg, np.ndarray):
result = target_function(np.require(arg, requirements=["C_CONTIGUOUS"]), **params)
if isinstance(result, np.ndarray):
res[key] = np.require(result, requirements=["C_CONTIGUOUS"])
else:
res[key] = result
else:
res[key] = target_function(arg, **params)
res[key] = ensure_contiguous_output(
target_function(ensure_contiguous_output(arg), **params),
)
else:
res[key] = arg
return res
Expand Down
8 changes: 8 additions & 0 deletions albumentations/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ def ensure_int_output(
return (int(min_val), int(max_val)) if isinstance(param, int) else (float(min_val), float(max_val))


def ensure_contiguous_output(arg: np.ndarray | Sequence[np.ndarray]) -> np.ndarray | list[np.ndarray]:
if isinstance(arg, np.ndarray):
arg = np.require(arg, requirements=["C_CONTIGUOUS"])
elif isinstance(arg, Sequence):
arg = list(map(ensure_contiguous_output, arg))
return arg


@overload
def to_tuple(param: ScaleIntType, low: ScaleType | None = None, bias: ScalarType | None = None) -> tuple[int, int]: ...

Expand Down
9 changes: 6 additions & 3 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ def test_to_tensor_v2_on_non_contiguous_array_with_random_rotate90():

img = np.random.randint(0, 256, (640, 480, 3)).astype(np.uint8)
masks = [np.random.randint(0, 2, (640, 480)).astype(np.uint8) for _ in range(4)]
transformed = transforms(image=img, masks=masks)
assert transformed["image"].numpy().shape == (3, 640, 480)
assert transformed["masks"][0].shape == (640, 480)
for _ in range(10):
transformed = transforms(image=img, masks=masks)
assert isinstance(transformed["image"], torch.Tensor)
assert isinstance(transformed["masks"][0], torch.Tensor)
assert transformed["image"].numpy().shape in ((3, 640, 480), (3, 480, 640))
assert transformed["masks"][0].shape in ((640, 480), (480, 640))

0 comments on commit d31a905

Please sign in to comment.