temp commit #1580

Closed · wants to merge 1 commit
17 changes: 7 additions & 10 deletions albumentations/augmentations/crops/functional.py
@@ -224,16 +224,13 @@ def bbox_crop(


def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int) -> np.ndarray:
h, w = img.shape[:2]
if x_min < 0:
x_min = 0
if y_min < 0:
y_min = 0
if y_max >= h:
y_max = h - 1
if x_max >= w:
x_max = w - 1
return img[int(y_min) : int(y_max), int(x_min) : int(x_max)]
height, width = img.shape[:2]
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(width, x_max + 1) # +1 because slice indices are non-inclusive at the top
y_max = min(height, y_max + 1) # +1 for the same reason

return img[y_min:y_max, x_min:x_max]


@preserve_channel_dim
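For context on the behavioural change above, a minimal sketch (assuming a NumPy H x W image; the function body mirrors the new diff): the rewritten helper clamps coordinates with max/min and treats x_max/y_max as inclusive, so a full-image crop no longer drops the last row and column.

```python
import numpy as np

# Sketch of the rewritten helper (body copied from the diff above).
def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int) -> np.ndarray:
    height, width = img.shape[:2]
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(width, x_max + 1)   # +1 because slice upper bounds are exclusive
    y_max = min(height, y_max + 1)
    return img[y_min:y_max, x_min:x_max]

img = np.zeros((100, 100, 3), dtype=np.uint8)
# The old code clamped x_max/y_max to w-1/h-1 and returned a 99x99 crop here.
assert clamping_crop(img, -5, -5, 99, 99).shape == (100, 100, 3)
# Out-of-range maxima are clamped to the image bounds.
assert clamping_crop(img, 10, 20, 300, 300).shape == (80, 90, 3)
```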
260 changes: 252 additions & 8 deletions albumentations/augmentations/mixing/transforms.py
@@ -1,18 +1,37 @@
import random
import types
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from warnings import warn

import numpy as np

from albumentations.augmentations.crops.transforms import RandomSizedBBoxSafeCrop
from albumentations.augmentations.functional import split_uniform_grid
from albumentations.augmentations.geometric.resize import LongestMaxSize, Resize, SmallestMaxSize
from albumentations.augmentations.geometric.transforms import PadIfNeeded
from albumentations.augmentations.utils import is_grayscale_image
from albumentations.core.composition import Compose
from albumentations.core.transforms_interface import ReferenceBasedTransform
from albumentations.core.types import BoxType, KeypointType, ReferenceImage, Targets
from albumentations.random_utils import beta
from albumentations.random_utils import beta, choice, shuffle

from .functional import mix_arrays

__all__ = ["MixUp"]
__all__ = ["MixUp", "Mosaic"]


class MixUp(ReferenceBasedTransform):
@@ -28,10 +47,10 @@ class MixUp(ReferenceBasedTransform):
In International Conference on Learning Representations. https://arxiv.org/abs/1710.09412

Args:
reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]]):
reference_data (Optional[Union[Generator[Any, None, None], Sequence[Any]]]):
A sequence or generator of dictionaries containing the reference data for mixing.
If None or an empty sequence is provided, no operation is performed and a warning is issued.
read_fn (Callable[[ReferenceImage], Dict[str, Any]]):
read_fn (Callable[[Any], ReferenceImage]):
A function to process items from reference_data. It should accept items from reference_data
and return a dictionary containing processed data:
- The returned dictionary must include an 'image' key with a numpy array value.
@@ -109,8 +128,8 @@ class MixUp(ReferenceBasedTransform):

def __init__(
self,
reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None,
read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None},
reference_data: Optional[Union[Generator[Any, None, None], Sequence[Any]]] = None,
read_fn: Callable[[Any], ReferenceImage] = lambda x: {"image": x, "mask": None, "global_label": None},
alpha: float = 0.4,
mix_coef_return_name: str = "mix_coef",
always_apply: bool = False,
@@ -174,7 +193,7 @@ def apply_to_keypoints(
def get_transform_init_args_names(self) -> Tuple[str, ...]:
return "reference_data", "alpha"

def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
def get_params(self) -> Dict[str, Any]:
mix_data = None
# Check if reference_data is not empty and is a sequence (list, tuple, np.array)
if isinstance(self.reference_data, Sequence) and not isinstance(self.reference_data, (str, bytes)):
@@ -206,3 +225,228 @@ def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -
if self.mix_coef_return_name:
res[self.mix_coef_return_name] = params["mix_coef"]
return res
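Usage note on the updated read_fn contract: reference_data may now hold arbitrary items, and read_fn maps each item to a ReferenceImage-style dict. A hedged sketch, assuming the items are image file paths and OpenCV is available (the paths and cv2 calls are illustrative, not part of this diff):

```python
import cv2
import albumentations as A

# Hypothetical reference items: file paths (any item type works, since read_fn receives it).
reference_paths = ["ref_0.jpg", "ref_1.jpg"]

def read_fn(path: str) -> dict:
    # Map an arbitrary item to the ReferenceImage layout expected by MixUp.
    image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    return {"image": image, "mask": None, "global_label": None}

mixup = A.MixUp(reference_data=reference_paths, read_fn=read_fn, alpha=0.4, p=1.0)
```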


class Mosaic(ReferenceBasedTransform):
"""Performs Mosaic data augmentation, combining multiple images into a single image for enhanced model training.

This transformation creates a composite image from multiple source images arranged in a grid, which can be uniform
or random based on the `split_mode`. The mosaic augmentation introduces variations in context, scale, and object
combinations, beneficial for object detection models.

Args:
reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]]):
A sequence or generator of dictionaries with the reference data for the mosaic.
If None or an empty sequence is provided, no operation is performed.
read_fn (Callable[[Any], ReferenceImage]):
A function to process items from `reference_data`. It must return a dictionary containing 'image',
and optionally 'mask' and 'global_label', each associated with numpy array values.
grid_size (Tuple[int, int]):
The size (rows, columns) of the grid to arrange images in the mosaic. Defaults to (3, 3).
split_mode (str):
Determines how the images are split and arranged in the mosaic. Can be 'uniform' for equal-sized tiles,
or 'random' for randomly sized tiles. Defaults to 'uniform'.
preprocessing_mode (str): How each reference item is fitted to its tile. One of 'resize',
'longest_max_size_pad', 'smallest_max_size_crop', or 'random_sized_bbox_safe_crop'. Defaults to 'resize'.
p (float):
The probability of applying the transformation. Defaults to 0.5.

Targets:
image, mask, global_label

Image types:
uint8, float32

Raises:
- ValueError: For invalid `grid_size` or `split_mode` values.
- NotImplementedError: If applied to bounding boxes or keypoints.
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.GLOBAL_LABEL)

def __init__(
self,
reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None,
read_fn: Callable[[Any], ReferenceImage] = lambda x: {
"image": x,
"mask": None,
"global_label": None,
},
grid_size: Tuple[int, int] = (3, 3),
split_mode: str = "uniform",
preprocessing_mode: str = "resize",
target_size: Tuple[int, int] = (1024, 1024),
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.grid_size = grid_size
self.split_mode = split_mode
self.target_size = target_size

if any(x <= 0 for x in self.grid_size):
msg = "grid_size must contain positive integers."
raise ValueError(msg)
if split_mode not in ["uniform", "random"]:
msg = "split_mode must be 'uniform' or 'random'."
raise ValueError(msg)

self.read_fn = read_fn

if reference_data is None:
warn("No reference data provided for Mosaic. This transform will act as a no-op.")
self.reference_data: List[Any] = []
elif isinstance(reference_data, (types.GeneratorType, Iterable)) and not isinstance(reference_data, str):
if isinstance(reference_data, Sequence) and len(reference_data) < self.grid_size[0] * self.grid_size[1] - 1:
msg = "Not enough reference data to fill the mosaic grid."
raise ValueError(msg)
self.reference_data = reference_data # type: ignore[assignment]
else:
msg = "reference_data must be a list, tuple, generator, or None."
raise TypeError(msg)

self.preprocessing_mode = preprocessing_mode

def apply(
self,
img: np.ndarray,
mix_data: List[ReferenceImage],
tiles: np.ndarray,
preprocessing_pipeline: Compose,
**params: Any,
) -> np.ndarray:
transformed_img = preprocessing_pipeline(image=img)["image"]
return self.apply_to_image_or_mask(transformed_img, "image", mix_data, tiles)

def apply_to_mask(self, mask: np.ndarray, mix_data: List[ReferenceImage], *args: Any, **params: Any) -> np.ndarray:
msg = "Mosaic does not support keypoints yet"
raise NotImplementedError(msg)
Comment on lines +320 to +322
issue (typo): The apply_to_mask method in the Mosaic class raises NotImplementedError with a message indicating that Mosaic does not support keypoints, which seems to be a copy-paste error from the apply_to_keypoints method. The error message should correctly reflect that it's the mask processing that is not implemented.

Suggested change
def apply_to_mask(self, mask: np.ndarray, mix_data: List[ReferenceImage], *args: Any, **params: Any) -> np.ndarray:
msg = "Mosaic does not support keypoints yet"
raise NotImplementedError(msg)
def apply_to_mask(self, mask: np.ndarray, mix_data: List[ReferenceImage], *args: Any, **params: Any) -> np.ndarray:
msg = "Mosaic does not support mask processing yet"
raise NotImplementedError(msg)


def apply_to_bbox(self, bbox: Any, *args: Any, **params: Any) -> Any:
msg = "Mosaic does not support bbox yet"
raise NotImplementedError(msg)

def apply_to_global_label(
self, label: np.ndarray, mix_data: List[ReferenceImage], *args: Any, **params: Any
) -> np.ndarray:
msg = "Mosaic does not support global label yet"
raise NotImplementedError(msg)

def sample_reference_data(self) -> List[Dict[str, Any]]:
total_tiles = self.grid_size[0] * self.grid_size[1]
sampled_reference_data: List[Any] = []

if isinstance(self.reference_data, Sequence) and len(self.reference_data):
# Select the required number of items without replacement
return choice(self.reference_data, self.grid_size[0] * self.grid_size[1] - 1, replace=False)

if isinstance(self.reference_data, Iterator):
# Get the necessary number of elements from the iterator
sampled_reference_data = []

try:
for _ in range(total_tiles - 1):
next_element = next(self.reference_data, None)
if next_element is None:
# The iterator doesn't have enough elements
warn("Reference data iterator has insufficient data to fill the mosaic grid.", RuntimeWarning)
# Reset mix_data as we can't fulfill the required grid tiles
return []
sampled_reference_data.append(next_element)
except StopIteration:
# This block is in case the iterator was shorter than expected and ran out before total_tiles
warn("Reference data iterator was exhausted before filling all tiles.", RuntimeWarning)
# Reset mix_data as we can't fulfill the required grid tiles
sampled_reference_data = []

return sampled_reference_data

def get_params(self) -> Dict[str, Any]:
sampled_reference_data = self.sample_reference_data()

tiles = split_uniform_grid(self.target_size, self.grid_size)

shuffle(tiles)

mix_data = []

for idx, tile in enumerate(tiles[:-1]): # last position in shuffled tiles is for target
element = sampled_reference_data[idx]
processed_element = self.read_fn(element)
# Extract the tile dimensions
tile_height, tile_width = tile[2] - tile[0], tile[3] - tile[1]
# Preprocess the element based on the tile size
processed_element = self.preprocess_element(processed_element, tile_height, tile_width)
mix_data.append(processed_element)

last_tile = tiles[-1]
last_tile_width = last_tile[3] - last_tile[1]
last_tile_height = last_tile[2] - last_tile[0]
preprocessing_pipeline = self.get_preprocessing_pipeline(last_tile_height, last_tile_width)

return {"mix_data": sampled_reference_data, "preprocessing_pipepline": preprocessing_pipeline}
issue (typo): There's a typo in the dictionary key preprocessing_pipepline; it should be preprocessing_pipeline. This typo could lead to runtime errors when accessing the returned dictionary values.

Suggested change
return {"mix_data": sampled_reference_data, "preprocessing_pipepline": preprocessing_pipeline}
return {"mix_data": sampled_reference_data, "preprocessing_pipeline": preprocessing_pipeline}


def get_preprocessing_pipeline(self, tile_height: int, tile_width: int) -> Compose:
if self.preprocessing_mode == "resize":
return Compose([Resize(height=tile_height, width=tile_width)])
if self.preprocessing_mode == "longest_max_size_pad":
return Compose(
[
LongestMaxSize(max_size=max(tile_height, tile_width)),
PadIfNeeded(min_height=tile_height, min_width=tile_width),
]
)
if self.preprocessing_mode == "smallest_max_size_crop":
return Compose(
[
SmallestMaxSize(max_size=min(tile_height, tile_width)),
RandomSizedBBoxSafeCrop(height=tile_height, width=tile_width),
]
)
if self.preprocessing_mode == "random_sized_bbox_safe_crop":
return Compose(
[
RandomSizedBBoxSafeCrop(height=tile_height, width=tile_width),
LongestMaxSize(max_size=max(tile_height, tile_width)),
PadIfNeeded(min_height=tile_height, min_width=tile_width),
]
)

raise ValueError(f"Unknown preprocessing_mode {self.preprocessing_mode}")

def preprocess_element(self, element: ReferenceImage, tile_height: int, tile_width: int) -> ReferenceImage:
preprocessing_pipeline = self.get_preprocessing_pipeline(tile_height, tile_width)

# Apply the preprocess pipeline to the image, mask, and other elements
return cast(ReferenceImage, preprocessing_pipeline(**element))

def apply_to_image_or_mask(
self, data: np.ndarray, data_key: Literal["image", "mask"], mix_data: List[ReferenceImage], tiles: np.ndarray
) -> np.ndarray:
"""Apply transformations to an image or mask based on mixed data and tile positioning.

Args:
data (np.ndarray): The original image or mask data.
data_key (str): The key in the processed elements dictionary that corresponds to the data
('image' or 'mask').
mix_data (List[ReferenceImage]): List of processed elements (dictionaries containing 'image', 'mask').
tiles (List[Tuple[int, int, int, int]]): List of tile coordinates.

Returns:
np.ndarray: The new image or mask after applying the mosaic transformations.
"""
new_data = np.empty(self.target_size)
for element, tile in zip(mix_data, tiles):
if data_key in element:
y_min, x_min, y_max, x_max = tile
element_data = element[data_key]
new_data[y_min:y_max, x_min:x_max] = element_data

last_tile = tiles[-1]
y_min, x_min, y_max, x_max = last_tile

new_data[y_min:y_max, x_min:x_max] = data

return new_data

def get_transform_init_args_names(self) -> Tuple[str, ...]:
return "reference_data", "grid_size", "split_mode", "preprocessing_mode"
51 changes: 19 additions & 32 deletions albumentations/core/composition.py
@@ -139,49 +139,36 @@ class Compose(BaseCompose):
def __init__(
self,
transforms: TransformsSeqType,
bbox_params: Optional[Union[Dict[str, Any], "BboxParams"]] = None,
keypoint_params: Optional[Union[Dict[str, Any], "KeypointParams"]] = None,
bbox_params: Optional[Union[Dict[str, Any], BboxParams]] = None,
keypoint_params: Optional[Union[Dict[str, Any], KeypointParams]] = None,
additional_targets: Optional[Dict[str, str]] = None,
p: float = 1.0,
is_check_shapes: bool = True,
):
super().__init__(transforms, p)
self.processors: Dict[str, Union["BboxProcessor", "KeypointsProcessor"]] = {}
self.additional_targets = additional_targets or {}
self.is_check_shapes = is_check_shapes
self.bbox_params = bbox_params
self.keypoint_params = keypoint_params

self.processors: Dict[str, Union[BboxProcessor, KeypointsProcessor]] = {}
if bbox_params:
if isinstance(bbox_params, dict):
b_params = BboxParams(**bbox_params)
elif isinstance(bbox_params, BboxParams):
b_params = bbox_params
else:
msg = "unknown format of bbox_params, please use `dict` or `BboxParams`"
raise ValueError(msg)
self.processors["bboxes"] = BboxProcessor(b_params, additional_targets)

self.processors["bboxes"] = self.init_processor(BboxProcessor, bbox_params)
if keypoint_params:
if isinstance(keypoint_params, dict):
k_params = KeypointParams(**keypoint_params)
elif isinstance(keypoint_params, KeypointParams):
k_params = keypoint_params
else:
msg = "unknown format of keypoint_params, please use `dict` or `KeypointParams`"
raise ValueError(msg)
self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets)

if additional_targets is None:
additional_targets = {}
self.processors["keypoints"] = self.init_processor(KeypointsProcessor, keypoint_params)

self.additional_targets = additional_targets
def init_processor(
self, processor_class: Any, params: Union[Dict[str, Any], "BboxParams", "KeypointParams"]
) -> Any:
if isinstance(params, dict):
return processor_class(**params, additional_targets=self.additional_targets)

for proc in self.processors.values():
proc.ensure_transforms_valid(self.transforms)
return processor_class(params, additional_targets=self.additional_targets)

self.add_targets(additional_targets)

self.is_check_args = True
self._disable_check_args_for_transforms(self.transforms)

self.is_check_shapes = is_check_shapes
def update_transforms_with_params(self) -> None:
for transform in self.transforms:
if isinstance(transform, BasicTransform):
transform.update_with_external_params(self.bbox_params, self.keypoint_params)

@staticmethod
def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
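As a usage sketch of what the new init_processor helper is meant to normalise, both of the following forms should produce an equivalent bbox processor (HorizontalFlip and the parameter values are illustrative assumptions, not part of this diff):

```python
import albumentations as A

# bbox_params given as a plain dict: init_processor unpacks it as keyword
# arguments for the processor class.
compose_from_dict = A.Compose(
    [A.HorizontalFlip(p=0.5)],
    bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
)

# bbox_params given as a BboxParams instance: passed through unchanged.
compose_from_params = A.Compose(
    [A.HorizontalFlip(p=0.5)],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)
```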