diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8a90d8b13..7a0df188d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -46,9 +46,9 @@ jobs:
       - name: Install PyTorch
         run: |
           if [ "${{ matrix.operating-system }}" = "macos-latest" ]; then
-            uv pip install --system torch==2.4.1 torchvision==0.19.1
+            uv pip install --system torch==2.5.1 torchvision==0.20.1
           else
-            uv pip install --system torch==2.4.1+cpu torchvision==0.19.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+            uv pip install --system torch==2.5.1+cpu torchvision==0.20.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
           fi
         shell: bash
diff --git a/README.md b/README.md
index 0b361d2b8..c403bf998 100644
--- a/README.md
+++ b/README.md
@@ -264,6 +264,7 @@ Spatial-level transforms will simultaneously change both an input image as well
 | [SafeRotate](https://explore.albumentations.ai/transform/SafeRotate) | ✓ | ✓ | ✓ | ✓ |
 | [ShiftScaleRotate](https://explore.albumentations.ai/transform/ShiftScaleRotate) | ✓ | ✓ | ✓ | ✓ |
 | [SmallestMaxSize](https://explore.albumentations.ai/transform/SmallestMaxSize) | ✓ | ✓ | ✓ | ✓ |
+| [TimeMasking](https://explore.albumentations.ai/transform/TimeMasking) | ✓ | ✓ | ✓ | ✓ |
 | [TimeReverse](https://explore.albumentations.ai/transform/TimeReverse) | ✓ | ✓ | ✓ | ✓ |
 | [Transpose](https://explore.albumentations.ai/transform/Transpose) | ✓ | ✓ | ✓ | ✓ |
 | [VerticalFlip](https://explore.albumentations.ai/transform/VerticalFlip) | ✓ | ✓ | ✓ | ✓ |
diff --git a/albumentations/augmentations/dropout/transforms.py b/albumentations/augmentations/dropout/transforms.py
index cff12ce3a..0cbf02ed4 100644
--- a/albumentations/augmentations/dropout/transforms.py
+++ b/albumentations/augmentations/dropout/transforms.py
@@ -54,6 +54,8 @@ def __init__(
         self.mask_fill_value = mask_fill_value
 
     def apply(self, img: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
+        if holes.size == 0:
+            return img
         if self.fill_value in {"inpaint_telea", "inpaint_ns"}:
             num_channels = get_num_channels(img)
             if num_channels not in {1, 3}:
@@ -61,7 +63,7 @@ def apply(self, img: np.ndarray, holes: np.ndarray, seed: int, **params: Any) ->
         return cutout(img, holes, self.fill_value, np.random.default_rng(seed))
 
     def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
-        if self.mask_fill_value is None:
+        if self.mask_fill_value is None or holes.size == 0:
             return mask
         return cutout(mask, holes, self.mask_fill_value, np.random.default_rng(seed))
 
@@ -71,6 +73,8 @@ def apply_to_bboxes(
         self,
         holes: np.ndarray,
         **params: Any,
     ) -> np.ndarray:
+        if holes.size == 0:
+            return bboxes
         processor = cast(BboxProcessor, self.get_processor("bboxes"))
         if processor is None:
             return bboxes
@@ -96,6 +100,8 @@ def apply_to_keypoints(
         self,
         holes: np.ndarray,
         **params: Any,
     ) -> np.ndarray:
+        if holes.size == 0:
+            return keypoints
         processor = cast(KeypointsProcessor, self.get_processor("keypoints"))
 
         if processor is None or not processor.params.remove_invisible:
diff --git a/albumentations/augmentations/dropout/xy_masking.py b/albumentations/augmentations/dropout/xy_masking.py
index 17b8d7065..9acf4941e 100644
--- a/albumentations/augmentations/dropout/xy_masking.py
+++ b/albumentations/augmentations/dropout/xy_masking.py
@@ -123,6 +123,7 @@ def get_params_dependent_on_data(
         masks_y = self.generate_masks(self.num_masks_y, image_shape, self.mask_y_length, axis="y")
 
         holes = np.array(masks_x + masks_y)
+        return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)}
 
     def generate_mask_size(self, mask_length: tuple[int, int]) -> int:
diff --git a/albumentations/augmentations/spectrogram/transform.py b/albumentations/augmentations/spectrogram/transform.py
index 0f5ec43d7..f433fc7c4 100644
--- a/albumentations/augmentations/spectrogram/transform.py
+++ b/albumentations/augmentations/spectrogram/transform.py
@@ -1,11 +1,17 @@
 from __future__ import annotations
 
+from warnings import warn
+
+from pydantic import Field
+
+from albumentations.augmentations.dropout.xy_masking import XYMasking
 from albumentations.augmentations.geometric.transforms import HorizontalFlip
 from albumentations.core.transforms_interface import BaseTransformInitSchema
 from albumentations.core.types import Targets
 
 __all__ = [
     "TimeReverse",
+    "TimeMasking",
 ]
 
 
@@ -53,4 +59,88 @@ def __init__(
         p: float = 0.5,
         always_apply: bool | None = None,
     ):
+        warn(
+            "TimeReverse is an alias for HorizontalFlip transform. "
+            "Consider using HorizontalFlip directly from albumentations.HorizontalFlip. ",
+            UserWarning,
+            stacklevel=2,
+        )
         super().__init__(p=p, always_apply=always_apply)
+
+
+class TimeMasking(XYMasking):
+    """Apply masking to a spectrogram in the time domain.
+
+    This transform masks random segments along the time axis of a spectrogram,
+    implementing the time masking technique proposed in the SpecAugment paper.
+    Time masking helps in training models to be robust against temporal variations
+    and missing information in audio signals.
+
+    This is a specialized version of XYMasking configured for time masking only.
+    For more advanced use cases (e.g., multiple masks, frequency masking, or custom
+    fill values), consider using XYMasking directly.
+
+    Args:
+        time_mask_param (int): Maximum possible length of the mask in the time domain.
+            Must be a positive integer. Length of the mask is uniformly sampled from [0, time_mask_param).
+        p (float): probability of applying the transform. Default: 0.5.
+
+    Targets:
+        image, mask, bboxes, keypoints
+
+    Image types:
+        uint8, float32
+
+    Number of channels:
+        Any
+
+    Note:
+        This transform is implemented as a subset of XYMasking with fixed parameters:
+        - Single horizontal mask (num_masks_x=1)
+        - No vertical masks (num_masks_y=0)
+        - Zero fill value
+        - Random mask length up to time_mask_param
+
+        For more flexibility, including:
+        - Multiple masks
+        - Custom fill values
+        - Frequency masking
+        - Combined time-frequency masking
+        Consider using albumentations.XYMasking directly.
+
+    References:
+        - SpecAugment paper: https://arxiv.org/abs/1904.08779
+        - Original implementation: https://pytorch.org/audio/stable/transforms.html#timemask
+    """
+
+    _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
+
+    class InitSchema(BaseTransformInitSchema):
+        time_mask_param: int = Field(ge=0)
+
+    def __init__(
+        self,
+        time_mask_param: int = 40,
+        p: float = 0.5,
+        always_apply: bool | None = None,
+    ):
+        warn(
+            "TimeMasking is a specialized version of XYMasking. "
" + "For more flexibility (multiple masks, custom fill values, frequency masking), " + "consider using XYMasking directly from albumentations.XYMasking.", + UserWarning, + stacklevel=2, + ) + super().__init__( + p=p, + always_apply=always_apply, + fill_value=0, + mask_fill_value=None, + mask_x_length=(0, time_mask_param), + num_masks_x=1, + num_masks_y=0, + ) + self.time_mask_param = time_mask_param + + def get_transform_init_args_names(self) -> tuple[str, ...]: + return ("time_mask_param",) diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 64539546f..e5c8c0f7f 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -552,8 +552,8 @@ def get_params_dependent_on_data( return result - def get_transform_init_args_names(self) -> tuple[str, str]: - return "snow_point_range", "brightness_coeff" + def get_transform_init_args_names(self) -> tuple[str, ...]: + return "snow_point_range", "brightness_coeff", "method" class RandomGravel(ImageOnlyTransform): diff --git a/tests/aug_definitions.py b/tests/aug_definitions.py index a10715a1d..f95ef37c0 100644 --- a/tests/aug_definitions.py +++ b/tests/aug_definitions.py @@ -380,4 +380,5 @@ [A.GridElasticDeform, {"num_grid_xy": (10, 10), "magnitude": 10}], [A.ShotNoise, {"scale_range": (0.1, 0.3)}], [A.TimeReverse, {}], + [A.TimeMasking, {"time_mask_param": 10}], ] diff --git a/tests/test_core.py b/tests/test_core.py index 1b005b988..a5a719cf7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1331,6 +1331,7 @@ def test_masks_as_target(augmentation_cls, params, masks): A.FromFloat, A.MaskDropout, A.XYMasking, + A.TimeMasking, }, ), ) @@ -1343,7 +1344,7 @@ def test_mask_interpolation(augmentation_cls, params, interpolation): image = SQUARE_UINT8_IMAGE mask = image.copy() - aug = A.Compose([augmentation_cls(p=1, interpolation=interpolation, mask_interpolation=interpolation, **params)]) + aug = A.Compose([augmentation_cls(p=1, interpolation=interpolation, mask_interpolation=interpolation, seed=42,**params)]) transformed = aug(image=image, mask=mask) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index dc2ec6a78..3f358d767 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1249,6 +1249,7 @@ def test_coarse_dropout_invalid_input(params): "reference_images": [SQUARE_UINT8_IMAGE + 1], "read_fn": lambda x: x, }, + A.TimeMasking: {"time_mask_param": 10}, }, except_augmentations={ A.RandomCropNearBBox, @@ -1313,6 +1314,7 @@ def test_change_image(augmentation_cls, params): A.FancyPCA: {"alpha": 1}, A.GridElasticDeform: {"num_grid_xy": (10, 10), "magnitude": 10}, A.RGBShift: {"r_shift_limit": (10, 10), "g_shift_limit": (10, 10), "b_shift_limit": (10, 10)}, + A.TimeMasking: {"time_mask_param": 10}, }, except_augmentations={ A.Crop,