Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add frequecny masking #2123

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [CropNonEmptyMaskIfExists](https://explore.albumentations.ai/transform/CropNonEmptyMaskIfExists) | ✓ | ✓ | ✓ | ✓ |
| [D4](https://explore.albumentations.ai/transform/D4) | ✓ | ✓ | ✓ | ✓ |
| [ElasticTransform](https://explore.albumentations.ai/transform/ElasticTransform) | ✓ | ✓ | ✓ | ✓ |
| [FrequencyMasking](https://explore.albumentations.ai/transform/FrequencyMasking) | ✓ | ✓ | ✓ | ✓ |
| [GridDistortion](https://explore.albumentations.ai/transform/GridDistortion) | ✓ | ✓ | ✓ | ✓ |
| [GridDropout](https://explore.albumentations.ai/transform/GridDropout) | ✓ | ✓ | ✓ | ✓ |
| [GridElasticDeform](https://explore.albumentations.ai/transform/GridElasticDeform) | ✓ | ✓ | ✓ | ✓ |
Expand Down
83 changes: 81 additions & 2 deletions albumentations/augmentations/spectrogram/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
__all__ = [
"TimeReverse",
"TimeMasking",
"FrequencyMasking",
]


Expand Down Expand Up @@ -82,7 +83,7 @@ class TimeMasking(XYMasking):

Args:
time_mask_param (int): Maximum possible length of the mask in the time domain.
Must be a positive integer. Length of the mask is uniformly sampled from [0, time_mask_param).
Must be a positive integer. Length of the mask is uniformly sampled from (0, time_mask_param).
p (float): probability of applying the transform. Default: 0.5.

Targets:
Expand Down Expand Up @@ -116,7 +117,7 @@ class TimeMasking(XYMasking):
_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
time_mask_param: int = Field(ge=0)
time_mask_param: int = Field(gt=0)

def __init__(
self,
Expand Down Expand Up @@ -144,3 +145,81 @@ def __init__(

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("time_mask_param",)


class FrequencyMasking(XYMasking):
ternaus marked this conversation as resolved.
Show resolved Hide resolved
"""Apply masking to a spectrogram in the frequency domain.

This transform masks random segments along the frequency axis of a spectrogram,
implementing the frequency masking technique proposed in the SpecAugment paper.
Frequency masking helps in training models to be robust against frequency variations
and missing spectral information in audio signals.

This is a specialized version of XYMasking configured for frequency masking only.
For more advanced use cases (e.g., multiple masks, time masking, or custom
fill values), consider using XYMasking directly.

Args:
freq_mask_param (int): Maximum possible length of the mask in the frequency domain.
Must be a positive integer. Length of the mask is uniformly sampled from (0, freq_mask_param).
p (float): probability of applying the transform. Default: 0.5.

Targets:
image, mask, bboxes, keypoints

Image types:
uint8, float32

Number of channels:
Any

Note:
This transform is implemented as a subset of XYMasking with fixed parameters:
- Single vertical mask (num_masks_y=1)
- No horizontal masks (num_masks_x=0)
- Zero fill value
- Random mask length up to freq_mask_param

For more flexibility, including:
- Multiple masks
- Custom fill values
- Time masking
- Combined time-frequency masking
Consider using albumentations.XYMasking directly.

References:
- SpecAugment paper: https://arxiv.org/abs/1904.08779
- Original implementation: https://pytorch.org/audio/stable/transforms.html#freqmask
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
freq_mask_param: int = Field(gt=0)

def __init__(
self,
freq_mask_param: int = 30,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"FrequencyMasking is a specialized version of XYMasking. "
"For more flexibility (multiple masks, custom fill values, time masking), "
"consider using XYMasking directly from albumentations.XYMasking.",
UserWarning,
stacklevel=2,
)
super().__init__(
p=p,
always_apply=always_apply,
fill_value=0,
mask_fill_value=None,
mask_y_length=(0, freq_mask_param),
num_masks_x=0,
num_masks_y=1,
)
self.freq_mask_param = freq_mask_param

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("freq_mask_param",)
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,5 @@
[A.ShotNoise, {"scale_range": (0.1, 0.3)}],
[A.TimeReverse, {}],
[A.TimeMasking, {"time_mask_param": 10}],
[A.FrequencyMasking, {"freq_mask_param": 10}],
ternaus marked this conversation as resolved.
Show resolved Hide resolved
]
1 change: 1 addition & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,7 @@ def test_masks_as_target(augmentation_cls, params, masks):
A.MaskDropout,
A.XYMasking,
A.TimeMasking,
A.FrequencyMasking,
},
),
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,6 @@ def test_coarse_dropout_invalid_input(params):
"reference_images": [SQUARE_UINT8_IMAGE + 1],
"read_fn": lambda x: x,
},
A.TimeMasking: {"time_mask_param": 10},
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down