Update lut #2000

Merged
merged 4 commits on Oct 19, 2024
Changes from 3 commits
116 changes: 43 additions & 73 deletions albumentations/augmentations/functional.py
@@ -25,6 +25,7 @@
multiply_add,
normalize_per_image,
preserve_channel_dim,
sz_lut,
to_float,
uint8_io,
)
@@ -89,54 +90,14 @@
]


def _shift_hsv_uint8(
img: np.ndarray,
hue_shift: np.ndarray,
sat_shift: np.ndarray,
val_shift: np.ndarray,
) -> np.ndarray:
dtype = img.dtype
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
hue, sat, val = cv2.split(img)

if hue_shift != 0:
lut_hue = np.arange(0, 256, dtype=np.int16)
lut_hue = np.mod(lut_hue + hue_shift, 180).astype(dtype)
hue = cv2.LUT(hue, lut_hue)

sat = add(sat, sat_shift)
val = add(val, val_shift)

img = cv2.merge((hue, sat, val)).astype(dtype)
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)


def _shift_hsv_non_uint8(
img: np.ndarray,
hue_shift: np.ndarray,
sat_shift: np.ndarray,
val_shift: np.ndarray,
) -> np.ndarray:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
hue, sat, val = cv2.split(img)

if hue_shift != 0:
hue = cv2.add(hue, hue_shift)
hue = np.mod(hue, 360) # OpenCV fails with negative values

sat = add(sat, sat_shift)
val = add(val, val_shift)

img = cv2.merge((hue, sat, val))
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)


@uint8_io
@preserve_channel_dim
def shift_hsv(img: np.ndarray, hue_shift: np.ndarray, sat_shift: np.ndarray, val_shift: np.ndarray) -> np.ndarray:
def shift_hsv(img: np.ndarray, hue_shift: float, sat_shift: float, val_shift: float) -> np.ndarray:
if hue_shift == 0 and sat_shift == 0 and val_shift == 0:
return img

is_gray = is_grayscale_image(img)

if is_gray:
if hue_shift != 0 or sat_shift != 0:
hue_shift = 0
@@ -148,16 +109,28 @@ def shift_hsv(img: np.ndarray, hue_shift: np.ndarray, sat_shift: np.ndarray, val
)
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

if img.dtype == np.uint8:
img = _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift)
else:
img = _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift)
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
hue, sat, val = cv2.split(img)

if hue_shift != 0:
lut_hue = np.arange(0, 256, dtype=np.int16)
lut_hue = np.mod(lut_hue + hue_shift, 180).astype(np.uint8)
hue = sz_lut(hue, lut_hue, inplace=False)

if sat_shift != 0:
sat = add(sat, sat_shift, inplace=False)

if val_shift != 0:
val = add(val, val_shift, inplace=False)

img = cv2.merge((hue, sat, val))
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)

return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if is_gray else img


@clipped
def solarize(img: np.ndarray, threshold: int = 128) -> np.ndarray:
def solarize(img: np.ndarray, threshold: int) -> np.ndarray:
"""Invert all pixel values above a threshold.

Args:
@@ -175,16 +148,15 @@ def solarize(img: np.ndarray, threshold: int = 128) -> np.ndarray:
lut = [(i if i < threshold else max_val - i) for i in range(int(max_val) + 1)]

prev_shape = img.shape
img = cv2.LUT(img, np.array(lut, dtype=dtype))
img = sz_lut(img, np.array(lut, dtype=dtype), inplace=False)

if len(prev_shape) != len(img.shape):
img = np.expand_dims(img, -1)
return np.expand_dims(img, -1)
return img

result_img = img.copy()
cond = img >= threshold
result_img[cond] = max_val - result_img[cond]
return result_img
img[cond] = max_val - img[cond]
return img


@uint8_io
@@ -213,7 +185,7 @@ def posterize(img: np.ndarray, bits: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8]) -> np.n
mask = ~np.uint8(2 ** (8 - bits_array) - 1)
lut &= mask

return cv2.LUT(img, lut)
return sz_lut(img, lut, inplace=False)

result_img = np.empty_like(img)
for i, channel_bits in enumerate(bits_array):
@@ -226,7 +198,7 @@ def posterize(img: np.ndarray, bits: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8]) -> np.n
mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
lut &= mask

result_img[..., i] = cv2.LUT(img[..., i], lut)
result_img[..., i] = sz_lut(img[..., i], lut, inplace=True)

return result_img

@@ -248,7 +220,7 @@ def _equalize_pil(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray
lut[i] = min(n // step, 255)
n += histogram[i]

return cv2.LUT(img, np.array(lut))
return sz_lut(img, np.array(lut), inplace=True)


def _equalize_cv(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
@@ -276,7 +248,7 @@ def _equalize_cv(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
_sum += histogram[idx]
lut[idx] = clip(round(_sum * scale), np.uint8)

return cv2.LUT(img, lut)
return sz_lut(img, lut, inplace=True)


def _check_preconditions(img: np.ndarray, mask: np.ndarray | None, by_channels: bool) -> None:
@@ -393,10 +365,12 @@ def evaluate_bez(t: np.ndarray, low_y: float | np.ndarray, high_y: float | np.nd

if np.isscalar(low_y) and np.isscalar(high_y):
lut = clip(np.rint(evaluate_bez(t, low_y, high_y)), np.uint8)
return cv2.LUT(img, lut)
return sz_lut(img, lut, inplace=False)
if isinstance(low_y, np.ndarray) and isinstance(high_y, np.ndarray):
luts = clip(np.rint(evaluate_bez(t[:, np.newaxis], low_y, high_y).T), np.uint8)
return cv2.merge([cv2.LUT(img[:, :, i], luts[i]) for i in range(num_channels)])
return cv2.merge(
[sz_lut(img[:, :, i], np.ascontiguousarray(luts[i]), inplace=False) for i in range(num_channels)],
)

raise TypeError(
f"low_y and high_y must both be of type float or np.ndarray. Got {type(low_y)} and {type(high_y)}",
@@ -614,7 +588,7 @@ def add_snow_texture(img: np.ndarray, snow_point: float, brightness_coeff: float
snow_layer = (np.dstack([snow_texture] * 3) * max_value * snow_point).astype(np.float32)

# Blend snow with original image
img_with_snow = cv2.addWeighted(img_hsv, 1, snow_layer, 1, 0)
img_with_snow = cv2.add(img_hsv, snow_layer)

# Add a slight blue tint to simulate cool snow color
blue_tint = np.full_like(img_with_snow, (0.6, 0.75, 1)) # Slight blue in HSV
@@ -814,6 +788,7 @@ def add_sun_flare_overlay(
num_times = src_radius // 10
alpha = np.linspace(0.0, 1, num=num_times)
rad = np.linspace(1, src_radius, num=num_times)

for i in range(num_times):
cv2.circle(overlay, point, int(rad[i]), src_color, -1)
alp = alpha[num_times - i - 1] * alpha[num_times - i - 1] * alpha[num_times - i - 1]
@@ -1003,7 +978,8 @@ def channel_shuffle(img: np.ndarray, channels_shuffled: np.ndarray) -> np.ndarra
def gamma_transform(img: np.ndarray, gamma: float) -> np.ndarray:
if img.dtype == np.uint8:
table = (np.arange(0, 256.0 / 255, 1.0 / 255) ** gamma) * 255
return cv2.LUT(img, table.astype(np.uint8))
return sz_lut(img, table.astype(np.uint8), inplace=False)

return np.power(img, gamma)


@@ -1019,7 +995,7 @@ def brightness_contrast_adjust(
else:
value = beta * np.mean(img)

return multiply_add(img, alpha, value)
return multiply_add(img, alpha, value, inplace=False)


@float32_io
@@ -1427,7 +1403,7 @@ def adjust_brightness_torchvision(img: np.ndarray, factor: np.ndarray) -> np:
if factor == 1:
return img

return multiply(img, factor)
return multiply(img, factor, inplace=False)


@preserve_channel_dim
@@ -1442,33 +1418,27 @@ def adjust_contrast_torchvision(img: np.ndarray, factor: float) -> np.ndarray:
mean = int(mean + 0.5)
return np.full_like(img, mean, dtype=img.dtype)

return multiply_add(img, factor, mean * (1 - factor))
return multiply_add(img, factor, mean * (1 - factor), inplace=False)


@clipped
@preserve_channel_dim
def adjust_saturation_torchvision(img: np.ndarray, factor: float, gamma: float = 0) -> np.ndarray:
if factor == 1:
return img

if is_grayscale_image(img):
if factor == 1 or is_grayscale_image(img):
return img

gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

if factor == 0:
return gray

return cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma)
return gray if factor == 0 else cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma)


def _adjust_hue_torchvision_uint8(img: np.ndarray, factor: float) -> np.ndarray:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

lut = np.arange(0, 256, dtype=np.int16)
lut = np.mod(lut + 180 * factor, 180).astype(np.uint8)
img[..., 0] = cv2.LUT(img[..., 0], lut)
img[..., 0] = sz_lut(img[..., 0], lut, inplace=False)

return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)

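The unifying change in functional.py is replacing `cv2.LUT` with albucore's `sz_lut` and passing an explicit `inplace` flag to the arithmetic helpers (`add`, `multiply`, `multiply_add`); the `@uint8_io` decorator presumably converts float inputs to uint8 and back, which is why the separate `_shift_hsv_non_uint8` branch could be dropped. A minimal sketch of the lookup-table pattern these helpers wrap, written against plain OpenCV/NumPy and assuming `sz_lut(img, lut, inplace=...)` behaves like `cv2.LUT` for uint8 data (with the in-place variant reusing the input buffer):

```python
import cv2
import numpy as np


def hue_shift_uint8(hue_plane: np.ndarray, hue_shift: int) -> np.ndarray:
    """Shift an 8-bit OpenCV hue plane (values 0..179) through a 256-entry LUT."""
    lut = np.arange(0, 256, dtype=np.int16)
    lut = np.mod(lut + hue_shift, 180).astype(np.uint8)  # hue wraps modulo 180 for uint8 images
    # cv2.LUT stands in for sz_lut(hue_plane, lut, inplace=False) here.
    return cv2.LUT(hue_plane, lut)


# Usage sketch: shift the hue of a random uint8 RGB image by 20 (OpenCV hue units).
img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2HSV))
shifted = cv2.cvtColor(cv2.merge((hue_shift_uint8(hue, 20), sat, val)), cv2.COLOR_HSV2RGB)
```

The choice between `inplace=True` and `inplace=False` throughout the diff appears to depend on whether the array being written is already a private working copy (as in the equalize and per-channel posterize paths) or the caller's input.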
2 changes: 1 addition & 1 deletion albumentations/augmentations/transforms.py
@@ -2251,7 +2251,7 @@ def __init__(

def apply(self, img: np.ndarray, shift: np.ndarray, **params: Any) -> np.ndarray:
non_rgb_error(img)
return albucore.add_vector(img, shift)
return albucore.add_vector(img, shift, inplace=False)

def get_params(self) -> dict[str, Any]:
return {
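Passing `inplace=False` here is the conservative option: `apply` receives the caller's image, and an in-place vector add would write the shift back into the user's data. A NumPy-only sketch of the distinction, using a hypothetical `add_vector` helper (not the albucore implementation):

```python
import numpy as np


def add_vector(img: np.ndarray, shift: np.ndarray, inplace: bool = False) -> np.ndarray:
    """Add a per-channel shift; with inplace=False the input array is left untouched."""
    out = img if inplace else img.copy()
    out += shift.astype(out.dtype)
    return out


img = np.zeros((2, 2, 3), dtype=np.float32)
shifted = add_vector(img, np.array([0.1, 0.2, 0.3]), inplace=False)
assert img.max() == 0.0  # the original image is unchanged
```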
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
"PyYAML",
"typing-extensions>=4.9.0; python_version<'3.10'",
"pydantic>=2.7.0",
"albucore==0.0.17",
"albucore==0.0.18",
"eval-type-backport"
]

2 changes: 1 addition & 1 deletion tests/test_functional.py
@@ -887,7 +887,7 @@ def test_random_tone_curve(image):
result_float_value = F.move_tone_curve(image, low_y, high_y)
result_array_value = F.move_tone_curve(image, np.array([low_y] * num_channels), np.array([high_y] * num_channels))

assert np.array_equal(result_float_value, result_array_value)
np.testing.assert_allclose(result_float_value, result_array_value)

assert result_float_value.dtype == image.dtype
assert result_float_value.shape == image.shape
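Replacing `np.array_equal` with `np.testing.assert_allclose` relaxes the check from bit-exact equality to a small numerical tolerance, which is the safer assertion once the scalar and per-channel tone-curve paths build their LUTs through slightly different code. A minimal illustration of the difference:

```python
import numpy as np

a = np.array([0.1 + 0.2, 1.0])
b = np.array([0.3, 1.0])

assert not np.array_equal(a, b)              # bit-exact comparison trips on float rounding
np.testing.assert_allclose(a, b, rtol=1e-7)  # tolerance-based comparison passes
```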
56 changes: 2 additions & 54 deletions tests/test_transforms.py
@@ -6,8 +6,8 @@
import cv2
import numpy as np
import pytest
from albucore.functions import to_float
from albucore.utils import clip
from albucore.functions import to_float, clip, from_float

from torchvision import transforms as torch_transforms

import albumentations as A
@@ -653,58 +653,6 @@ def test_color_jitter_float_uint8_equal(brightness, contrast, saturation, hue):
assert _max <= 2, f"Max: {_max}"


@pytest.mark.parametrize(["hue", "sat", "val"], [[13, 17, 23], [14, 18, 24], [131, 143, 151], [132, 144, 152]])
def test_hue_saturation_value_float_uint8_equal(hue, sat, val):
img = SQUARE_UINT8_IMAGE

for i in range(2):
sign = 1 if i == 0 else -1
for i in range(4):
if i == 0:
_hue = hue * sign
_sat = 0
_val = 0
elif i == 1:
_hue = 0
_sat = sat * sign
_val = 0
elif i == 2:
_hue = 0
_sat = 0
_val = val * sign
else:
_hue = hue * sign
_sat = sat * sign
_val = val * sign

t1 = A.Compose(
[
A.HueSaturationValue(
hue_shift_limit=[_hue, _hue],
sat_shift_limit=[_sat, _sat],
val_shift_limit=[_val, _val],
p=1,
),
],
)
t2 = A.Compose(
[
A.HueSaturationValue(
hue_shift_limit=[_hue / 180 * 360, _hue / 180 * 360],
sat_shift_limit=[_sat / 255, _sat / 255],
val_shift_limit=[_val / 255, _val / 255],
p=1,
),
],
)

res1 = t1(image=img)["image"]
res2 = (t2(image=img.astype(np.float32) / 255.0)["image"] * 255).astype(np.uint8)

_max = np.abs(res1.astype(np.int32) - res2).max()
assert _max <= 10, f"Max value: {_max}"


def test_perspective_keep_size():
h, w = 100, 100
img = np.zeros([h, w, 3], dtype=np.uint8)
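The deleted test asserted that uint8 and float inputs to `HueSaturationValue` stay within a small tolerance of each other once the shift limits are rescaled between the two value ranges; with `shift_hsv` now routed through `@uint8_io`, that float branch is presumably no longer a separate code path. The rescaling the test relied on, kept here as a reference sketch:

```python
# Shift limits for uint8 images vs. float images in [0, 1].
hue_uint8, sat_uint8, val_uint8 = 13, 17, 23

hue_float = hue_uint8 / 180 * 360  # uint8 hue spans 0..179, float hue spans 0..360
sat_float = sat_uint8 / 255        # uint8 saturation/value span 0..255,
val_float = val_uint8 / 255        # float saturation/value span 0..1
```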