Skip to content

Commit

Permalink
Add time masking (#2122)
Browse files Browse the repository at this point in the history
* Empty-Commit

* Added TimeMasking

* Added TimeMasking

* Updated tests
  • Loading branch information
ternaus authored Nov 8, 2024
1 parent 2a3a18a commit d538952
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ jobs:
- name: Install PyTorch
run: |
if [ "${{ matrix.operating-system }}" = "macos-latest" ]; then
uv pip install --system torch==2.4.1 torchvision==0.19.1
uv pip install --system torch==2.5.1 torchvision==0.20.1
else
uv pip install --system torch==2.4.1+cpu torchvision==0.19.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
uv pip install --system torch==2.5.1+cpu torchvision==0.20.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
fi
shell: bash

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [SafeRotate](https://explore.albumentations.ai/transform/SafeRotate) |||||
| [ShiftScaleRotate](https://explore.albumentations.ai/transform/ShiftScaleRotate) |||||
| [SmallestMaxSize](https://explore.albumentations.ai/transform/SmallestMaxSize) |||||
| [TimeMasking](https://explore.albumentations.ai/transform/TimeMasking) |||||
| [TimeReverse](https://explore.albumentations.ai/transform/TimeReverse) |||||
| [Transpose](https://explore.albumentations.ai/transform/Transpose) |||||
| [VerticalFlip](https://explore.albumentations.ai/transform/VerticalFlip) |||||
Expand Down
8 changes: 7 additions & 1 deletion albumentations/augmentations/dropout/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ def __init__(
self.mask_fill_value = mask_fill_value

def apply(self, img: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
if holes.size == 0:
return img
if self.fill_value in {"inpaint_telea", "inpaint_ns"}:
num_channels = get_num_channels(img)
if num_channels not in {1, 3}:
raise ValueError("Inpainting works only for 1 or 3 channel images")
return cutout(img, holes, self.fill_value, np.random.default_rng(seed))

def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, seed: int, **params: Any) -> np.ndarray:
if self.mask_fill_value is None:
if self.mask_fill_value is None or holes.size == 0:
return mask
return cutout(mask, holes, self.mask_fill_value, np.random.default_rng(seed))

Expand All @@ -71,6 +73,8 @@ def apply_to_bboxes(
holes: np.ndarray,
**params: Any,
) -> np.ndarray:
if holes.size == 0:
return bboxes
processor = cast(BboxProcessor, self.get_processor("bboxes"))
if processor is None:
return bboxes
Expand All @@ -96,6 +100,8 @@ def apply_to_keypoints(
holes: np.ndarray,
**params: Any,
) -> np.ndarray:
if holes.size == 0:
return keypoints
processor = cast(KeypointsProcessor, self.get_processor("keypoints"))

if processor is None or not processor.params.remove_invisible:
Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/dropout/xy_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_params_dependent_on_data(
masks_y = self.generate_masks(self.num_masks_y, image_shape, self.mask_y_length, axis="y")

holes = np.array(masks_x + masks_y)

return {"holes": holes, "seed": self.random_generator.integers(0, 2**32 - 1)}

def generate_mask_size(self, mask_length: tuple[int, int]) -> int:
Expand Down
90 changes: 90 additions & 0 deletions albumentations/augmentations/spectrogram/transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

from warnings import warn

from pydantic import Field

from albumentations.augmentations.dropout.xy_masking import XYMasking
from albumentations.augmentations.geometric.transforms import HorizontalFlip
from albumentations.core.transforms_interface import BaseTransformInitSchema
from albumentations.core.types import Targets

__all__ = [
"TimeReverse",
"TimeMasking",
]


Expand Down Expand Up @@ -53,4 +59,88 @@ def __init__(
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"TimeReverse is an alias for HorizontalFlip transform. "
"Consider using HorizontalFlip directly from albumentations.HorizontalFlip. ",
UserWarning,
stacklevel=2,
)
super().__init__(p=p, always_apply=always_apply)


class TimeMasking(XYMasking):
"""Apply masking to a spectrogram in the time domain.
This transform masks random segments along the time axis of a spectrogram,
implementing the time masking technique proposed in the SpecAugment paper.
Time masking helps in training models to be robust against temporal variations
and missing information in audio signals.
This is a specialized version of XYMasking configured for time masking only.
For more advanced use cases (e.g., multiple masks, frequency masking, or custom
fill values), consider using XYMasking directly.
Args:
time_mask_param (int): Maximum possible length of the mask in the time domain.
Must be a positive integer. Length of the mask is uniformly sampled from [0, time_mask_param).
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
Number of channels:
Any
Note:
This transform is implemented as a subset of XYMasking with fixed parameters:
- Single horizontal mask (num_masks_x=1)
- No vertical masks (num_masks_y=0)
- Zero fill value
- Random mask length up to time_mask_param
For more flexibility, including:
- Multiple masks
- Custom fill values
- Frequency masking
- Combined time-frequency masking
Consider using albumentations.XYMasking directly.
References:
- SpecAugment paper: https://arxiv.org/abs/1904.08779
- Original implementation: https://pytorch.org/audio/stable/transforms.html#timemask
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
time_mask_param: int = Field(ge=0)

def __init__(
self,
time_mask_param: int = 40,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"TimeMasking is a specialized version of XYMasking. "
"For more flexibility (multiple masks, custom fill values, frequency masking), "
"consider using XYMasking directly from albumentations.XYMasking.",
UserWarning,
stacklevel=2,
)
super().__init__(
p=p,
always_apply=always_apply,
fill_value=0,
mask_fill_value=None,
mask_x_length=(0, time_mask_param),
num_masks_x=1,
num_masks_y=0,
)
self.time_mask_param = time_mask_param

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("time_mask_param",)
4 changes: 2 additions & 2 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,8 @@ def get_params_dependent_on_data(

return result

def get_transform_init_args_names(self) -> tuple[str, str]:
return "snow_point_range", "brightness_coeff"
def get_transform_init_args_names(self) -> tuple[str, ...]:
return "snow_point_range", "brightness_coeff", "method"


class RandomGravel(ImageOnlyTransform):
Expand Down
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,5 @@
[A.GridElasticDeform, {"num_grid_xy": (10, 10), "magnitude": 10}],
[A.ShotNoise, {"scale_range": (0.1, 0.3)}],
[A.TimeReverse, {}],
[A.TimeMasking, {"time_mask_param": 10}],
]
3 changes: 2 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,7 @@ def test_masks_as_target(augmentation_cls, params, masks):
A.FromFloat,
A.MaskDropout,
A.XYMasking,
A.TimeMasking,
},
),
)
Expand All @@ -1343,7 +1344,7 @@ def test_mask_interpolation(augmentation_cls, params, interpolation):
image = SQUARE_UINT8_IMAGE
mask = image.copy()

aug = A.Compose([augmentation_cls(p=1, interpolation=interpolation, mask_interpolation=interpolation, **params)])
aug = A.Compose([augmentation_cls(p=1, interpolation=interpolation, mask_interpolation=interpolation, seed=42,**params)])

transformed = aug(image=image, mask=mask)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,7 @@ def test_coarse_dropout_invalid_input(params):
"reference_images": [SQUARE_UINT8_IMAGE + 1],
"read_fn": lambda x: x,
},
A.TimeMasking: {"time_mask_param": 10},
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -1313,6 +1314,7 @@ def test_change_image(augmentation_cls, params):
A.FancyPCA: {"alpha": 1},
A.GridElasticDeform: {"num_grid_xy": (10, 10), "magnitude": 10},
A.RGBShift: {"r_shift_limit": (10, 10), "g_shift_limit": (10, 10), "b_shift_limit": (10, 10)},
A.TimeMasking: {"time_mask_param": 10},
},
except_augmentations={
A.Crop,
Expand Down

0 comments on commit d538952

Please sign in to comment.