Skip to content

Commit

Permalink
Fixes (#1557)
Browse files Browse the repository at this point in the history
* Fixes

* Fixes
  • Loading branch information
ternaus authored Mar 4, 2024
1 parent 138bc5c commit 6494ef3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 39 deletions.
46 changes: 26 additions & 20 deletions albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import random
from typing import Any, Callable, Dict, Generator, Iterator, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Sequence, Tuple, Union
from warnings import warn

import numpy as np

from albumentations.augmentations.utils import is_grayscale_image
from albumentations.core.transforms_interface import ReferenceBasedTransform
from albumentations.core.types import ReferenceImage, Targets
from albumentations.core.types import BoxType, KeypointType, ReferenceImage, Targets
from albumentations.random_utils import beta

from .functional import mix_arrays
Expand All @@ -27,19 +27,15 @@ class MixUp(ReferenceBasedTransform):
In International Conference on Learning Representations. https://arxiv.org/abs/1710.09412
Args:
----
reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[ReferenceImage]]]):
A sequence or generator of dictionaries containing the reference data for mixing. Each dictionary
should contain:
- 'image': Mandatory key with an image array.
- 'mask': Optional key with a mask array.
- 'global_label': Optional key with a class label array.
- 'keypoints': Optional key with a list of keypoints.
reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]]):
A sequence or generator of dictionaries containing the reference data for mixing.
If None or an empty sequence is provided, no operation is performed and a warning is issued.
read_fn (Callable[[ReferenceImage], Dict[str, Any]]):
A function to process items from reference_data. It should accept a dictionary from reference_data
and return a processed dictionary containing 'image', and optionally 'mask', 'keypoints', 'global_label',
each as numpy arrays. Defaults to a no-op lambda function.
A function to process items from reference_data. It should accept items from reference_data
and return a dictionary containing processed data:
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask' and 'global_label' keys, each associated with a numpy array value.
Defaults to a function that assumes the input dictionary already contains numpy arrays and returns it unchanged.
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Higher values lead to more uniform mixing. Defaults to 0.4.
Expand All @@ -53,16 +49,12 @@ class MixUp(ReferenceBasedTransform):
- uint8, float32
Raises:
------
- ValueError: If the alpha parameter is negative.
- NotImplementedError: If the transform is applied to bounding boxes.
- NotImplementedError: If the transform is applied to keypoints.
- NotImplementedError: If the transform is applied to bounding boxes or keypoints.
Notes:
-----
- If no reference data is provided, a warning is issued, and the transform acts as a no-op.
- If images are in float32 format, they should be within the [0, 1] range.
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.GLOBAL_LABEL)
Expand All @@ -87,7 +79,11 @@ def __init__(
if reference_data is None:
warn("No reference data provided for MixUp. This transform will act as a no-op.")
# Create an empty generator
self.reference_data = reference_data or []
elif isinstance(reference_data, Iterable) and not isinstance(reference_data, str):
self.reference_data = reference_data
else:
msg = "reference_data must be a list, tuple, generator, or None."
raise TypeError(msg)

def apply(self, img: np.ndarray, mix_data: ReferenceImage, mix_coef: float, **params: Any) -> np.ndarray:
mix_img = mix_data.get("image")
Expand All @@ -110,6 +106,16 @@ def apply_to_global_label(
return mix_coef * label + (1 - mix_coef) * mix_label
return label

def apply_to_bboxes(self, bboxes: Sequence[BoxType], mix_data: ReferenceImage, **params: Any) -> Sequence[BoxType]:
    """Reject bounding-box targets; MixUp has no implementation for them.

    Args:
    ----
        bboxes: Bounding boxes passed to the transform (unused).
        mix_data: Reference data selected for mixing (unused).
        **params: Additional transform parameters (unused).

    Raises:
    ------
        NotImplementedError: Always, whenever this method is called.
    """
    raise NotImplementedError(
        "MixUp does not support bounding boxes yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/."
    )

def apply_to_keypoints(
    self, keypoints: Sequence[KeypointType], *args: Any, **params: Any
) -> Sequence[KeypointType]:
    """Reject keypoint targets; MixUp has no implementation for them.

    Args:
    ----
        keypoints: Keypoints passed to the transform (unused).
        *args: Additional positional arguments (unused).
        **params: Additional transform parameters (unused).

    Raises:
    ------
        NotImplementedError: Always, whenever this method is called.
    """
    raise NotImplementedError(
        "MixUp does not support keypoints yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/."
    )

def get_transform_init_args_names(self) -> Tuple[str, ...]:
    """Return the names of the constructor arguments reported for this transform.

    Returns
    -------
        The tuple ``("reference_data", "alpha")``.
    """
    init_arg_names = ("reference_data", "alpha")
    return init_arg_names

Expand Down
27 changes: 9 additions & 18 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ColorType,
KeypointInternalType,
KeypointType,
ReferenceImage,
ScalarType,
ScaleType,
Targets,
Expand Down Expand Up @@ -260,9 +259,6 @@ class DualTransform(BasicTransform):
apply_to_keypoint(keypoint: KeypointInternalType, *args: Any, **params: Any) -> KeypointInternalType:
Applies the transform to a single keypoint. Should be implemented in the subclass.
apply_to_global_label(label: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
Applies the transform to a single label. Should be implemented in the subclass.
apply_to_bboxes(bboxes: Sequence[BoxType], *args: Any, **params: Any) -> Sequence[BoxType]:
Applies the transform to a list of bounding boxes. Delegates to `apply_to_bbox` for each bounding box.
Expand All @@ -275,9 +271,6 @@ class DualTransform(BasicTransform):
apply_to_masks(masks: Sequence[np.ndarray], **params: Any) -> List[np.ndarray]:
Applies the transform to a list of masks. Delegates to `apply_to_mask` for each mask.
apply_to_global_labels(labels: Sequence[np.ndarray], **params: Any) -> List[np.ndarray]:
Applies the transform to a list of labels. Delegates to `apply_to_global_label` for each label.
Note:
----
This class is intended to be subclassed and should not be used directly. Subclasses are expected to
Expand All @@ -294,8 +287,6 @@ def targets(self) -> Dict[str, Callable[..., Any]]:
"masks": self.apply_to_masks,
"bboxes": self.apply_to_bboxes,
"keypoints": self.apply_to_keypoints,
"global_label": self.apply_to_global_label,
"global_labels": self.apply_to_global_labels,
}

def apply_to_bbox(self, bbox: BoxInternalType, *args: Any, **params: Any) -> BoxInternalType:
Expand Down Expand Up @@ -374,12 +365,12 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]:


class ReferenceBasedTransform(DualTransform):
def apply_to_bboxes(self, bboxes: Sequence[BoxType], mix_data: ReferenceImage, **params: Any) -> Sequence[BoxType]:
msg = "Transform does not support bounding boxes yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/."
raise NotImplementedError(msg)

def apply_to_keypoints(
self, keypoints: Sequence[KeypointType], *args: Any, **params: Any
) -> Sequence[KeypointType]:
msg = "Transform does not support keypoints yet, feel free to submit pull request to https://github.com/albumentations-team/albumentations/."
raise NotImplementedError(msg)
@property
def targets(self) -> Dict[str, Callable[..., Any]]:
return {
"global_label": self.apply_to_global_label,
"image": self.apply,
"mask": self.apply_to_mask,
"bboxes": self.apply_to_bboxes,
"keypoints": self.apply_to_keypoints,
}
20 changes: 19 additions & 1 deletion tests/test_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_image_only(augmentation_cls, params, image):
data = aug(image=image)
assert data["image"].dtype == np.uint8


@pytest.mark.parametrize(
["augmentation_cls", "params"],
[(A.MixUp, {
Expand Down Expand Up @@ -141,3 +140,22 @@ def test_keypoint_error(image, mask, global_label, keypoints):

with pytest.raises(NotImplementedError):
aug(image=image, global_label=global_label, mask=mask, keypoints=keypoints)


@pytest.mark.parametrize(
    ["augmentation_cls", "params"],
    [
        (A.CLAHE, {"p": 1}),
        (A.HorizontalFlip, {"p": 1}),
    ],
)
def test_pipeline(augmentation_cls, params, image, mask, global_label):
    """Check that MixUp works as the final step of a Compose pipeline.

    Each parametrized transform is chained with a MixUp instance (p=1,
    identity ``read_fn``) that mixes against a single random reference
    sample. The test verifies that the output image keeps the uint8 dtype
    and that the coefficient returned by ``find_mix_coef`` for the global
    label lies in [0, 1].
    """
    # Build the single reference sample; the three arrays are created in the
    # same order as the dictionary keys so the random draws stay in sequence.
    reference_image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
    reference_mask = np.random.randint(0, 256, (100, 100, 1), dtype=np.uint8)
    reference_label = np.array([0, 0, 1])
    reference_data = [
        {
            "image": reference_image,
            "mask": reference_mask,
            "global_label": reference_label,
        }
    ]

    mix_up = A.MixUp(p=1, reference_data=reference_data, read_fn=lambda x: x)
    pipeline = A.Compose([augmentation_cls(**params), mix_up], p=1)

    result = pipeline(image=image, global_label=global_label, mask=mask)

    assert result["image"].dtype == np.uint8

    mix_coeff_label = find_mix_coef(result["global_label"], global_label, reference_data[0]["global_label"])

    assert 0 <= mix_coeff_label <= 1

0 comments on commit 6494ef3

Please sign in to comment.