Skip to content

Commit

Permalink
Fixed Distortions
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Oct 23, 2024
1 parent 09ddef0 commit 9fae992
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 13 deletions.
50 changes: 38 additions & 12 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from albumentations import random_utils
from albumentations.augmentations.utils import angle_2pi_range, handle_empty_array
from albumentations.core.bbox_utils import bboxes_from_masks, denormalize_bboxes, masks_from_bboxes, normalize_bboxes
from albumentations.core.bbox_utils import bboxes_from_masks, denormalize_bboxes, normalize_bboxes
from albumentations.core.types import (
MONO_CHANNEL_DIMENSIONS,
NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS,
NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS,
NUM_MULTI_CHANNEL_DIMENSIONS,
REFLECT_BORDER_MODES,
Expand Down Expand Up @@ -1555,21 +1555,47 @@ def distortion_bboxes(
map_x: np.ndarray,
map_y: np.ndarray,
image_shape: tuple[int, int],
border_mode: int,
) -> np.ndarray:
result = bboxes.copy()
height, width = image_shape[:2]

masks = np.transpose(masks_from_bboxes(bboxes, image_shape), (1, 2, 0))
transformed_masks = cv2.remap(masks, map_x, map_y, cv2.INTER_NEAREST, borderMode=border_mode, borderValue=0)
# Convert bboxes to corner points: (N, 4) -> (N*2, 2)
corners = np.vstack(
[
bboxes[:, [0, 1]], # top-left corners
bboxes[:, [2, 3]], # bottom-right corners
],
)

if transformed_masks.ndim == MONO_CHANNEL_DIMENSIONS:
transformed_masks = np.expand_dims(transformed_masks, axis=0)
else:
transformed_masks = np.transpose(transformed_masks, (2, 0, 1))
# Transform corners using distortion_keypoints
transformed_corners = distortion_keypoints(
np.column_stack([corners, np.zeros(len(corners)), np.zeros(len(corners))]), # add dummy angle and scale
map_x,
map_y,
image_shape,
)

result[:, :4] = bboxes_from_masks(transformed_masks)
# Reshape back to bboxes format: (N*2, 2) -> (N, 4)
num_boxes = len(bboxes)
transformed_corners = transformed_corners[:, :2].reshape(2, num_boxes, 2)

return result
# Get min/max coordinates to form new bounding boxes
mins = transformed_corners[0] # top-left corners
maxs = transformed_corners[1] # bottom-right corners

new_bboxes = np.column_stack(
[
np.minimum(mins[:, 0], maxs[:, 0]), # x_min
np.minimum(mins[:, 1], maxs[:, 1]), # y_min
np.maximum(mins[:, 0], maxs[:, 0]), # x_max
np.maximum(mins[:, 1], maxs[:, 1]), # y_max
],
)

return (
np.column_stack([new_bboxes, bboxes[:, 4:]])
if bboxes.shape[1] > NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS
else new_bboxes
)


def generate_displacement_fields(
Expand Down
2 changes: 1 addition & 1 deletion albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def apply_to_mask(self, mask: np.ndarray, map_x: np.ndarray, map_y: np.ndarray,
def apply_to_bboxes(self, bboxes: np.ndarray, map_x: np.ndarray, map_y: np.ndarray, **params: Any) -> np.ndarray:
image_shape = params["shape"][:2]
bboxes_denorm = denormalize_bboxes(bboxes, image_shape)
bboxes_returned = fgeometric.distortion_bboxes(bboxes_denorm, map_x, map_y, image_shape, self.border_mode)
bboxes_returned = fgeometric.distortion_bboxes(bboxes_denorm, map_x, map_y, image_shape)
return normalize_bboxes(bboxes_returned, image_shape)

def apply_to_keypoints(
Expand Down
84 changes: 84 additions & 0 deletions tests/test_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,3 +1539,87 @@ def test_random_resized_crop():
labels = [1,2,3,4]
res = transform(image=np.zeros((500,500,3), dtype='uint8'), bboxes=boxes, label=labels)
assert len(res['bboxes']) == len(res['label'])


@pytest.mark.parametrize(
["bboxes", "map_x", "map_y", "image_shape", "expected"],
[
# Test case 1: Identity mapping (no distortion)
(
np.array([[10, 20, 30, 40]]), # single bbox
np.tile(np.arange(100), (100, 1)), # identity map_x
np.tile(np.arange(100).reshape(-1, 1), (1, 100)), # identity map_y
(100, 100),
np.array([[10, 20, 30, 40]]),
),
# Test case 2: Simple translation
(
np.array([[10, 20, 30, 40]]),
np.tile(np.arange(100) - 5, (100, 1)), # shift right by 5
np.tile(np.arange(100).reshape(-1, 1) - 5, (1, 100)), # shift down by 5
(100, 100),
np.array([[15, 25, 35, 45]]),
),
# Test case 3: Multiple bboxes with additional attributes
(
np.array([
[10, 20, 30, 40, 1], # bbox with class label
[50, 60, 70, 80, 2],
]),
np.tile(np.arange(100), (100, 1)), # identity map_x
np.tile(np.arange(100).reshape(-1, 1), (1, 100)), # identity map_y
(100, 100),
np.array([
[10, 20, 30, 40, 1],
[50, 60, 70, 80, 2],
]),
),
# Test case 4: Boundary conditions
(
np.array([[0, 0, 10, 10]]), # bbox at image corner
np.tile(np.arange(100), (100, 1)),
np.tile(np.arange(100).reshape(-1, 1), (1, 100)),
(100, 100),
np.array([[0, 0, 10, 10]]),
),
# Test case 5: Empty array
(
np.zeros((0, 4)), # empty bbox array
np.tile(np.arange(100), (100, 1)),
np.tile(np.arange(100).reshape(-1, 1), (1, 100)),
(100, 100),
np.zeros((0, 4)),
),
],
)
def test_distortion_bboxes(bboxes, map_x, map_y, image_shape, expected):
result = fgeometric.distortion_bboxes(bboxes, map_x, map_y, image_shape)
np.testing.assert_array_almost_equal(result, expected)


def test_distortion_bboxes_complex_distortion():
# Test with a more complex distortion pattern
bboxes = np.array([[25, 25, 75, 75]]) # center box
image_shape = (100, 100)

# Create a radial distortion pattern
y, x = np.mgrid[0:100, 0:100]
c_x, c_y = 50, 50 # distortion center
r = np.sqrt((x - c_x)**2 + (y - c_y)**2)
factor = 1 + r/100 # increasing distortion with radius

map_x = x + (x - c_x) / factor
map_y = y + (y - c_y) / factor

result = fgeometric.distortion_bboxes(bboxes, map_x, map_y, image_shape)

# Check that the result is different from input but still valid
assert not np.array_equal(result, bboxes)
assert np.all(result >= 0)
assert np.all(result[:, [0, 2]] <= image_shape[1]) # x coordinates
assert np.all(result[:, [1, 3]] <= image_shape[0]) # y coordinates
assert np.all(result[:, [0, 1]] <= result[:, [2, 3]]) # min <= max

0 comments on commit 9fae992

Please sign in to comment.