Skip to content

Commit

Permalink
Add frequecny masking (#2123)
Browse files Browse the repository at this point in the history
* Empty-Commit

* Added FrequencyMasking

* Fix in tests

* Fix
  • Loading branch information
ternaus authored Nov 8, 2024
1 parent d538952 commit ec9c442
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [CropNonEmptyMaskIfExists](https://explore.albumentations.ai/transform/CropNonEmptyMaskIfExists) |||||
| [D4](https://explore.albumentations.ai/transform/D4) |||||
| [ElasticTransform](https://explore.albumentations.ai/transform/ElasticTransform) |||||
| [FrequencyMasking](https://explore.albumentations.ai/transform/FrequencyMasking) |||||
| [GridDistortion](https://explore.albumentations.ai/transform/GridDistortion) |||||
| [GridDropout](https://explore.albumentations.ai/transform/GridDropout) |||||
| [GridElasticDeform](https://explore.albumentations.ai/transform/GridElasticDeform) |||||
Expand Down
83 changes: 81 additions & 2 deletions albumentations/augmentations/spectrogram/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
__all__ = [
"TimeReverse",
"TimeMasking",
"FrequencyMasking",
]


Expand Down Expand Up @@ -82,7 +83,7 @@ class TimeMasking(XYMasking):
Args:
time_mask_param (int): Maximum possible length of the mask in the time domain.
Must be a positive integer. Length of the mask is uniformly sampled from [0, time_mask_param).
Must be a positive integer. Length of the mask is uniformly sampled from (0, time_mask_param).
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand Down Expand Up @@ -116,7 +117,7 @@ class TimeMasking(XYMasking):
_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
time_mask_param: int = Field(ge=0)
time_mask_param: int = Field(gt=0)

def __init__(
self,
Expand Down Expand Up @@ -144,3 +145,81 @@ def __init__(

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("time_mask_param",)


class FrequencyMasking(XYMasking):
"""Apply masking to a spectrogram in the frequency domain.
This transform masks random segments along the frequency axis of a spectrogram,
implementing the frequency masking technique proposed in the SpecAugment paper.
Frequency masking helps in training models to be robust against frequency variations
and missing spectral information in audio signals.
This is a specialized version of XYMasking configured for frequency masking only.
For more advanced use cases (e.g., multiple masks, time masking, or custom
fill values), consider using XYMasking directly.
Args:
freq_mask_param (int): Maximum possible length of the mask in the frequency domain.
Must be a positive integer. Length of the mask is uniformly sampled from (0, freq_mask_param).
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
Number of channels:
Any
Note:
This transform is implemented as a subset of XYMasking with fixed parameters:
- Single vertical mask (num_masks_y=1)
- No horizontal masks (num_masks_x=0)
- Zero fill value
- Random mask length up to freq_mask_param
For more flexibility, including:
- Multiple masks
- Custom fill values
- Time masking
- Combined time-frequency masking
Consider using albumentations.XYMasking directly.
References:
- SpecAugment paper: https://arxiv.org/abs/1904.08779
- Original implementation: https://pytorch.org/audio/stable/transforms.html#freqmask
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
freq_mask_param: int = Field(gt=0)

def __init__(
self,
freq_mask_param: int = 30,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"FrequencyMasking is a specialized version of XYMasking. "
"For more flexibility (multiple masks, custom fill values, time masking), "
"consider using XYMasking directly from albumentations.XYMasking.",
UserWarning,
stacklevel=2,
)
super().__init__(
p=p,
always_apply=always_apply,
fill_value=0,
mask_fill_value=None,
mask_y_length=(0, freq_mask_param),
num_masks_x=0,
num_masks_y=1,
)
self.freq_mask_param = freq_mask_param

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("freq_mask_param",)
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,5 @@
[A.ShotNoise, {"scale_range": (0.1, 0.3)}],
[A.TimeReverse, {}],
[A.TimeMasking, {"time_mask_param": 10}],
[A.FrequencyMasking, {"freq_mask_param": 10}],
]
1 change: 1 addition & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,7 @@ def test_masks_as_target(augmentation_cls, params, masks):
A.MaskDropout,
A.XYMasking,
A.TimeMasking,
A.FrequencyMasking,
},
),
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,6 @@ def test_coarse_dropout_invalid_input(params):
"reference_images": [SQUARE_UINT8_IMAGE + 1],
"read_fn": lambda x: x,
},
A.TimeMasking: {"time_mask_param": 10},
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down

0 comments on commit ec9c442

Please sign in to comment.