Skip to content

Commit

Permalink
Add ChromaticAberration (#1589)
Browse files Browse the repository at this point in the history
* add ChromaticAberration

* add pre-commit hook changes

* remove unused import

* refactor random functions [#1591]

---------

Co-authored-by: Vladimir Iglovikov <[email protected]>
  • Loading branch information
mrsmrynk and ternaus authored Mar 18, 2024
1 parent 23444c9 commit f52c8db
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [CLAHE](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.CLAHE)
- [ChannelDropout](https://albumentations.ai/docs/api_reference/augmentations/dropout/channel_dropout/#albumentations.augmentations.dropout.channel_dropout.ChannelDropout)
- [ChannelShuffle](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ChannelShuffle)
- [ChromaticAberration](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ChromaticAberration)
- [ColorJitter](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter)
- [Defocus](https://albumentations.ai/docs/api_reference/augmentations/blur/transforms/#albumentations.augmentations.blur.transforms.Defocus)
- [Downscale](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Downscale)
Expand Down
78 changes: 77 additions & 1 deletion albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
preserve_channel_dim,
preserve_shape,
)
from albumentations.core.types import ColorType, ImageMode, ScalarType, SpatterMode, image_modes
from albumentations.core.types import (
ColorType,
ImageMode,
ScalarType,
SpatterMode,
image_modes,
)

__all__ = [
"add_fog",
Expand Down Expand Up @@ -62,6 +68,7 @@
"unsharp_mask",
"MAX_VALUES_BY_DTYPE",
"split_uniform_grid",
"chromatic_aberration",
]

TWO = 2
Expand Down Expand Up @@ -1422,3 +1429,72 @@ def split_uniform_grid(image_shape: Tuple[int, int], grid: Tuple[int, int]) -> n
]

return np.array(tiles)


def chromatic_aberration(
img: np.ndarray,
primary_distortion_red: float,
secondary_distortion_red: float,
primary_distortion_blue: float,
secondary_distortion_blue: float,
interpolation: int,
) -> np.ndarray:
non_rgb_warning(img)

height, width = img.shape[:2]

# Build camera matrix
camera_mat = np.eye(3, dtype=np.float32)
camera_mat[0, 0] = width
camera_mat[1, 1] = height
camera_mat[0, 2] = width / 2.0
camera_mat[1, 2] = height / 2.0

# Build distortion coefficients
distortion_coeffs_red = np.array([primary_distortion_red, secondary_distortion_red, 0, 0], dtype=np.float32)
distortion_coeffs_blue = np.array([primary_distortion_blue, secondary_distortion_blue, 0, 0], dtype=np.float32)

# Distort the red and blue channels
red_distorted = _distort_channel(
img[..., 0],
camera_mat,
distortion_coeffs_red,
height,
width,
interpolation,
)
blue_distorted = _distort_channel(
img[..., 2],
camera_mat,
distortion_coeffs_blue,
height,
width,
interpolation,
)

return np.dstack([red_distorted, img[..., 1], blue_distorted])


def _distort_channel(
channel: np.ndarray,
camera_mat: np.ndarray,
distortion_coeffs: np.ndarray,
height: int,
width: int,
interpolation: int,
) -> np.ndarray:
map_x, map_y = cv2.initUndistortRectifyMap(
cameraMatrix=camera_mat,
distCoeffs=distortion_coeffs,
R=None,
newCameraMatrix=camera_mat,
size=(width, height),
m1type=cv2.CV_32FC1,
)
return cv2.remap(
channel,
map_x,
map_y,
interpolation=interpolation,
borderMode=cv2.BORDER_REPLICATE,
)
122 changes: 122 additions & 0 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from albumentations.core.transforms_interface import DualTransform, ImageOnlyTransform, Interpolation, NoOp, to_tuple
from albumentations.core.types import (
BoxInternalType,
ChromaticAberrationMode,
ImageMode,
KeypointInternalType,
ScaleFloatType,
Expand Down Expand Up @@ -77,6 +78,7 @@
"UnsharpMask",
"PixelDropout",
"Spatter",
"ChromaticAberration",
]

HUNDRED = 100
Expand Down Expand Up @@ -2744,3 +2746,123 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A

def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]:
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color"


class ChromaticAberration(ImageOnlyTransform):
"""Add lateral chromatic aberration by distorting the red and blue channels of the input image.
Args:
primary_distortion_limit: range of the primary radial distortion coefficient.
If primary_distortion_limit is a single float value, the range will be
(-primary_distortion_limit, primary_distortion_limit).
Controls the distortion in the center of the image (positive values result in pincushion distortion,
negative values result in barrel distortion).
Default: 0.02.
secondary_distortion_limit: range of the secondary radial distortion coefficient.
If secondary_distortion_limit is a single float value, the range will be
(-secondary_distortion_limit, secondary_distortion_limit).
Controls the distortion in the corners of the image (positive values result in pincushion distortion,
negative values result in barrel distortion).
Default: 0.05.
mode: type of color fringing.
Supported modes are 'green_purple', 'red_blue' and 'random'.
'random' will choose one of the modes 'green_purple' or 'red_blue' randomly.
Default: 'green_purple'.
interpolation: flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
p: probability of applying the transform.
Default: 0.5.
Targets:
image
Image types:
uint8, float32
"""

def __init__(
self,
primary_distortion_limit: ScaleFloatType = 0.02,
secondary_distortion_limit: ScaleFloatType = 0.05,
mode: ChromaticAberrationMode = "green_purple",
interpolation: int = cv2.INTER_LINEAR,
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.primary_distortion_limit = to_tuple(primary_distortion_limit)
self.secondary_distortion_limit = to_tuple(secondary_distortion_limit)
self.mode = self._validate_mode(mode)
self.interpolation = interpolation

@staticmethod
def _validate_mode(
mode: ChromaticAberrationMode,
) -> ChromaticAberrationMode:
valid_modes = ["green_purple", "red_blue", "random"]
if mode not in valid_modes:
msg = f"Unsupported mode: {mode}. Supported modes are 'green_purple', 'red_blue', 'random'."
raise ValueError(msg)
return mode

def apply(
self,
img: np.ndarray,
primary_distortion_red: float = -0.02,
secondary_distortion_red: float = -0.05,
primary_distortion_blue: float = -0.02,
secondary_distortion_blue: float = -0.05,
**params: Any,
) -> np.ndarray:
return F.chromatic_aberration(
img,
primary_distortion_red,
secondary_distortion_red,
primary_distortion_blue,
secondary_distortion_blue,
cast(int, self.interpolation),
)

def get_params(self) -> Dict[str, float]:
primary_distortion_red = random_utils.uniform(*self.primary_distortion_limit)
secondary_distortion_red = random_utils.uniform(*self.secondary_distortion_limit)
primary_distortion_blue = random_utils.uniform(*self.primary_distortion_limit)
secondary_distortion_blue = random_utils.uniform(*self.secondary_distortion_limit)

secondary_distortion_red = self._match_sign(primary_distortion_red, secondary_distortion_red)
secondary_distortion_blue = self._match_sign(primary_distortion_blue, secondary_distortion_blue)

if self.mode == "green_purple":
# distortion coefficients of the red and blue channels have the same sign
primary_distortion_blue = self._match_sign(primary_distortion_red, primary_distortion_blue)
secondary_distortion_blue = self._match_sign(secondary_distortion_red, secondary_distortion_blue)
if self.mode == "red_blue":
# distortion coefficients of the red and blue channels have the opposite sign
primary_distortion_blue = self._unmatch_sign(primary_distortion_red, primary_distortion_blue)
secondary_distortion_blue = self._unmatch_sign(secondary_distortion_red, secondary_distortion_blue)

return {
"primary_distortion_red": primary_distortion_red,
"secondary_distortion_red": secondary_distortion_red,
"primary_distortion_blue": primary_distortion_blue,
"secondary_distortion_blue": secondary_distortion_blue,
}

@staticmethod
def _match_sign(a: float, b: float) -> float:
# Match the sign of b to a
if (a < 0 < b) or (a > 0 > b):
b = -b
return b

@staticmethod
def _unmatch_sign(a: float, b: float) -> float:
# Unmatch the sign of b to a
if (a < 0 and b < 0) or (a > 0 and b > 0):
b = -b
return b

def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
return "primary_distortion_limit", "secondary_distortion_limit", "mode", "interpolation"
1 change: 1 addition & 0 deletions albumentations/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@


SpatterMode = Literal["rain", "mud"]
ChromaticAberrationMode = Literal["green_purple", "red_blue", "random"]


class ReferenceImage(TypedDict):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
A.UnsharpMask,
A.RandomCropFromBorders,
A.Spatter,
A.ChromaticAberration,
},
),
)
Expand Down Expand Up @@ -546,6 +547,7 @@ def test_mask_fill_value(augmentation_cls, params):
A.FancyPCA,
A.PixelDistributionAdaptation,
A.Spatter,
A.ChromaticAberration,
},
),
)
Expand Down Expand Up @@ -624,6 +626,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params):
A.RandomToneCurve,
A.PixelDistributionAdaptation,
A.Spatter,
A.ChromaticAberration,
},
),
)
Expand Down Expand Up @@ -693,6 +696,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params):
A.HistogramMatching,
A.PixelDistributionAdaptation,
A.Spatter,
A.ChromaticAberration,
},
),
)
Expand Down Expand Up @@ -767,6 +771,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params
A.HistogramMatching,
A.PixelDistributionAdaptation,
A.Spatter,
A.ChromaticAberration,
},
),
)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,15 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
mode="mud",
),
],
[
A.ChromaticAberration,
dict(
primary_distortion_limit=0.02,
secondary_distortion_limit=0.05,
mode="green_purple",
interpolation=cv2.INTER_LINEAR,
),
],
[A.Defocus, {"radius": (5, 7), "alias_blur": (0.2, 0.6)}],
[A.ZoomBlur, {"max_factor": (1.56, 1.7), "step_factor": (0.02, 0.04)}],
[
Expand Down
10 changes: 10 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,7 @@ def check_center(vector):
A.RandomSunFlare(),
A.RandomShadow(),
A.Spatter(),
A.ChromaticAberration(),
],
)
@pytest.mark.parametrize("img_channels", [1, 6])
Expand All @@ -1317,6 +1318,15 @@ def test_spatter_incorrect_mode(image):
assert str(exc_info.value).startswith(message)


def test_chromatic_aberration_incorrect_mode(image):
unsupported_mode = "unsupported"
with pytest.raises(ValueError) as exc_info:
A.ChromaticAberration(mode=unsupported_mode)

message = f"Unsupported mode: {unsupported_mode}. Supported modes are 'green_purple', 'red_blue', 'random'."
assert str(exc_info.value).startswith(message)


@pytest.mark.parametrize(
"unsupported_color,mode,message",
[
Expand Down

0 comments on commit f52c8db

Please sign in to comment.