diff --git a/README.md b/README.md index c403bf998..eaf4d4620 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,7 @@ Spatial-level transforms will simultaneously change both an input image as well | [CropNonEmptyMaskIfExists](https://explore.albumentations.ai/transform/CropNonEmptyMaskIfExists) | ✓ | ✓ | ✓ | ✓ | | [D4](https://explore.albumentations.ai/transform/D4) | ✓ | ✓ | ✓ | ✓ | | [ElasticTransform](https://explore.albumentations.ai/transform/ElasticTransform) | ✓ | ✓ | ✓ | ✓ | +| [FrequencyMasking](https://explore.albumentations.ai/transform/FrequencyMasking) | ✓ | ✓ | ✓ | ✓ | | [GridDistortion](https://explore.albumentations.ai/transform/GridDistortion) | ✓ | ✓ | ✓ | ✓ | | [GridDropout](https://explore.albumentations.ai/transform/GridDropout) | ✓ | ✓ | ✓ | ✓ | | [GridElasticDeform](https://explore.albumentations.ai/transform/GridElasticDeform) | ✓ | ✓ | ✓ | ✓ | diff --git a/albumentations/augmentations/spectrogram/transform.py b/albumentations/augmentations/spectrogram/transform.py index f433fc7c4..5ae7f8a44 100644 --- a/albumentations/augmentations/spectrogram/transform.py +++ b/albumentations/augmentations/spectrogram/transform.py @@ -12,6 +12,7 @@ __all__ = [ "TimeReverse", "TimeMasking", + "FrequencyMasking", ] @@ -82,7 +83,7 @@ class TimeMasking(XYMasking): Args: time_mask_param (int): Maximum possible length of the mask in the time domain. - Must be a positive integer. Length of the mask is uniformly sampled from [0, time_mask_param). + Must be a positive integer. Length of the mask is uniformly sampled from (0, time_mask_param). p (float): probability of applying the transform. Default: 0.5. 
class FrequencyMasking(XYMasking):
    """Mask a random band of a spectrogram along the frequency axis.

    Convenience wrapper around XYMasking that reproduces the frequency masking
    augmentation described in the SpecAugment paper. Hiding part of the spectrum
    is meant to make models more robust to frequency variations and to missing
    spectral information in audio signals.

    Only frequency masking is configured here. For multiple masks, time masking,
    or custom fill values, use XYMasking directly.

    Args:
        freq_mask_param (int): Maximum possible length of the mask in the frequency domain.
            Must be a positive integer. The mask length is sampled from the range
            (0, freq_mask_param), which is forwarded to XYMasking as ``mask_y_length``.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32

    Number of channels:
        Any

    Note:
        The transform delegates to XYMasking with these fixed parameters:
        - One mask along the Y axis (num_masks_y=1)
        - No masks along the X axis (num_masks_x=0)
        - Zero fill value (fill_value=0)
        - Mask length range of (0, freq_mask_param)

        Creating an instance emits a UserWarning that points to XYMasking as the
        more flexible alternative, covering:
        - Multiple masks
        - Custom fill values
        - Time masking
        - Combined time-frequency masking

    References:
        - SpecAugment paper: https://arxiv.org/abs/1904.08779
        - Original implementation: https://pytorch.org/audio/stable/transforms.html#freqmask
    """

    _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

    class InitSchema(BaseTransformInitSchema):
        # Strictly positive: a zero or negative bound is rejected by validation.
        freq_mask_param: int = Field(gt=0)

    def __init__(
        self,
        freq_mask_param: int = 30,
        p: float = 0.5,
        always_apply: bool | None = None,
    ):
        advisory = (
            "FrequencyMasking is a specialized version of XYMasking. "
            "For more flexibility (multiple masks, custom fill values, time masking), "
            "consider using XYMasking directly from albumentations.XYMasking."
        )
        # stacklevel=2 is kept from the original call, which is made directly in __init__.
        warn(advisory, UserWarning, stacklevel=2)

        # Only the Y-axis length range depends on user input; everything else is fixed.
        y_length_range = (0, freq_mask_param)
        super().__init__(
            num_masks_x=0,
            num_masks_y=1,
            mask_y_length=y_length_range,
            fill_value=0,
            mask_fill_value=None,
            always_apply=always_apply,
            p=p,
        )
        self.freq_mask_param = freq_mask_param

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        # The single user-facing constructor argument besides the base-class ones.
        init_arg_names: tuple[str, ...] = ("freq_mask_param",)
        return init_arg_names
except_augmentations={ A.RandomCropNearBBox,