diff --git a/albumentations/augmentations/mixing/transforms.py b/albumentations/augmentations/mixing/transforms.py index 7edfb171b..639c04f76 100644 --- a/albumentations/augmentations/mixing/transforms.py +++ b/albumentations/augmentations/mixing/transforms.py @@ -1,12 +1,12 @@ import random -from typing import Any, Callable, Dict, Generator, Iterator, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np from albumentations.augmentations.utils import is_grayscale_image from albumentations.core.transforms_interface import ReferenceBasedTransform -from albumentations.core.types import ReferenceImage, Targets +from albumentations.core.types import BoxType, KeypointType, ReferenceImage, Targets from albumentations.random_utils import beta from .functional import mix_arrays @@ -27,19 +27,15 @@ class MixUp(ReferenceBasedTransform): In International Conference on Learning Representations. https://arxiv.org/abs/1710.09412 Args: - ---- - reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[ReferenceImage]]]): - A sequence or generator of dictionaries containing the reference data for mixing. Each dictionary - should contain: - - 'image': Mandatory key with an image array. - - 'mask': Optional key with a mask array. - - 'global_label': Optional key with a class label array. - - 'keypoints': Optional key with a list of keypoints. + reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]]): + A sequence or generator of dictionaries containing the reference data for mixing If None or an empty sequence is provided, no operation is performed and a warning is issued. read_fn (Callable[[ReferenceImage], Dict[str, Any]]): - A function to process items from reference_data. It should accept a dictionary from reference_data - and return a processed dictionary containing 'image', and optionally 'mask', 'keypoints', 'global_label', - each as numpy arrays. Defaults to a no-op lambda function. + A function to process items from reference_data. It should accept items from reference_data + and return a dictionary containing processed data: + - The returned dictionary must include an 'image' key with a numpy array value. + - It may also include 'mask', 'global_label' each associated with numpy array values. + Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it. alpha (float): The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0. Higher values lead to more uniform mixing. Defaults to 0.4. @@ -53,16 +49,12 @@ class MixUp(ReferenceBasedTransform): - uint8, float32 Raises: - ------ - ValueError: If the alpha parameter is negative. - - NotImplementedError: If the transform is applied to bounding boxes. - - NotImplementedError: If the transform is applied to keypoints. + - NotImplementedError: If the transform is applied to bounding boxes or keypoints. Notes: - ----- - If no reference data is provided, a warning is issued, and the transform acts as a no-op. - - + - Notes if images are in float32 format, they should be within [0, 1] range. """ _targets = (Targets.IMAGE, Targets.MASK, Targets.GLOBAL_LABEL) @@ -87,7 +79,11 @@ def __init__( if reference_data is None: warn("No reference data provided for MixUp. This transform will act as a no-op.") # Create an empty generator - self.reference_data = reference_data or [] + elif isinstance(reference_data, Iterable) and not isinstance(reference_data, str): + self.reference_data = reference_data + else: + msg = "reference_data must be a list, tuple, generator, or None." + raise TypeError(msg) def apply(self, img: np.ndarray, mix_data: ReferenceImage, mix_coef: float, **params: Any) -> np.ndarray: mix_img = mix_data.get("image") @@ -110,6 +106,16 @@ def apply_to_global_label( return mix_coef * label + (1 - mix_coef) * mix_label return label + def apply_to_bboxes(self, bboxes: Sequence[BoxType], mix_data: ReferenceImage, **params: Any) -> Sequence[BoxType]: + msg = "MixUp does not support bounding boxes yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/." + raise NotImplementedError(msg) + + def apply_to_keypoints( + self, keypoints: Sequence[KeypointType], *args: Any, **params: Any + ) -> Sequence[KeypointType]: + msg = "MixUp does not support keypoints yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/." + raise NotImplementedError(msg) + def get_transform_init_args_names(self) -> Tuple[str, ...]: return "reference_data", "alpha" diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index e2635520d..8d99bcd78 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -13,7 +13,6 @@ ColorType, KeypointInternalType, KeypointType, - ReferenceImage, ScalarType, ScaleType, Targets, @@ -260,9 +259,6 @@ class DualTransform(BasicTransform): apply_to_keypoint(keypoint: KeypointInternalType, *args: Any, **params: Any) -> KeypointInternalType: Applies the transform to a single keypoint. Should be implemented in the subclass. - apply_to_global_label(label: np.ndarray, *args: Any, **params: Any) -> np.ndarray: - Applies the transform to a single label. Should be implemented in the subclass. - apply_to_bboxes(bboxes: Sequence[BoxType], *args: Any, **params: Any) -> Sequence[BoxType]: Applies the transform to a list of bounding boxes. Delegates to `apply_to_bbox` for each bounding box. @@ -275,9 +271,6 @@ class DualTransform(BasicTransform): apply_to_masks(masks: Sequence[np.ndarray], **params: Any) -> List[np.ndarray]: Applies the transform to a list of masks. Delegates to `apply_to_mask` for each mask. - apply_to_global_labels(labels: Sequence[np.ndarray], **params: Any) -> List[np.ndarray]: - Applies the transform to a list of labels. Delegates to `apply_to_label` for each label. - Note: ---- This class is intended to be subclassed and should not be used directly. Subclasses are expected to @@ -294,8 +287,6 @@ def targets(self) -> Dict[str, Callable[..., Any]]: "masks": self.apply_to_masks, "bboxes": self.apply_to_bboxes, "keypoints": self.apply_to_keypoints, - "global_label": self.apply_to_global_label, - "global_labels": self.apply_to_global_labels, } def apply_to_bbox(self, bbox: BoxInternalType, *args: Any, **params: Any) -> BoxInternalType: @@ -374,12 +365,12 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]: class ReferenceBasedTransform(DualTransform): - def apply_to_bboxes(self, bboxes: Sequence[BoxType], mix_data: ReferenceImage, **params: Any) -> Sequence[BoxType]: - msg = "Transform does not support bounding boxes yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/." - raise NotImplementedError(msg) - - def apply_to_keypoints( - self, keypoints: Sequence[KeypointType], *args: Any, **params: Any - ) -> Sequence[KeypointType]: - msg = "Transform does not support keypoints yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/." - raise NotImplementedError(msg) + @property + def targets(self) -> Dict[str, Callable[..., Any]]: + return { + "global_label": self.apply_to_global_label, + "image": self.apply, + "mask": self.apply_to_mask, + "bboxes": self.apply_to_bboxes, + "keypoints": self.apply_to_keypoints, + } diff --git a/tests/test_mixing.py b/tests/test_mixing.py index d8d0e41dc..32375c65b 100644 --- a/tests/test_mixing.py +++ b/tests/test_mixing.py @@ -35,7 +35,6 @@ def test_image_only(augmentation_cls, params, image): data = aug(image=image) assert data["image"].dtype == np.uint8 - @pytest.mark.parametrize( ["augmentation_cls", "params"], [(A.MixUp, { @@ -141,3 +140,22 @@ def test_keypoint_error(image, mask, global_label, keypoints): with pytest.raises(NotImplementedError): aug(image=image, global_label=global_label, mask=mask, keypoints=keypoints) + + +@pytest.mark.parametrize( ["augmentation_cls", "params"], [(A.CLAHE, {"p": 1}), (A.HorizontalFlip, {"p": 1})]) +def test_pipeline(augmentation_cls, params, image, mask, global_label): + reference_data =[{"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8), + "mask": np.random.randint(0, 256, (100, 100, 1), dtype=np.uint8), + "global_label": np.array([0, 0, 1])}] + + mix_up = A.MixUp(p=1, reference_data=reference_data, read_fn=lambda x: x) + + aug = A.Compose([augmentation_cls(**params), mix_up], p=1) + + data = aug(image=image, global_label=global_label, mask=mask) + + assert data["image"].dtype == np.uint8 + + mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data[0]["global_label"]) + + assert 0 <= mix_coeff_label <= 1