Skip to content

Commit

Permalink
Fix augmentations/geometric
Browse files (browse the repository at this point in the history)
  • Loading branch information
Dipet committed Mar 14, 2024
1 parent 502993a commit 6e70334
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 21 deletions.
40 changes: 28 additions & 12 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
"bboxes_transpose",
"vflip",
"hflip",
"vflip_cv2",
"hflip_cv2",
"transpose",
"keypoints_flip",
Expand Down Expand Up @@ -99,18 +98,19 @@ def bboxes_rot90(bboxes: BoxesArray, factor: int, rows: int, cols: int) -> Boxes
"""
if factor not in {0, 1, 2, 3}:
raise ValueError("Parameter n must be in set {0, 1, 2, 3}")
msg = "Parameter n must be in set {0, 1, 2, 3}"
raise ValueError(msg)

if not factor or not len(bboxes):
return bboxes

if factor == 1:
bboxes[:, [0, 2]] = 1 - bboxes[:, [0, 2]]
bboxes = bboxes[:, [1, 2, 3, 0]]
elif factor == 2:
elif factor == TWO:
bboxes = 1 - bboxes
bboxes = bboxes[:, [2, 3, 0, 1]]
elif factor == 3:
elif factor == THREE:
bboxes[:, [1, 3]] = 1 - bboxes[:, [1, 3]]
bboxes = bboxes[:, [3, 0, 1, 2]]
return bboxes
Expand All @@ -136,17 +136,18 @@ def keypoints_rot90(keypoints: KeypointsArray, factor: int, rows: int, cols: int
"""
if factor not in {0, 1, 2, 3}:
raise ValueError("Parameter n must be in set {0, 1, 2, 3}")
msg = "Parameter n must be in set {0, 1, 2, 3}"
raise ValueError(msg)

if factor == 1:
keypoints[..., 2] -= math.pi / 2
keypoints[..., 0] = cols - 1 - keypoints[..., 0]
keypoints[..., [0, 1]] = keypoints[..., [1, 0]]
elif factor == 2:
elif factor == TWO:
keypoints[..., 2] -= math.pi
keypoints[..., 0] = cols - 1 - keypoints[..., 0]
keypoints[..., 1] = rows - 1 - keypoints[..., 1]
elif factor == 3:
elif factor == THREE:
keypoints[..., 2] += math.pi / 2
keypoints[..., 1] = rows - 1 - keypoints[..., 1]
keypoints[..., [0, 1]] = keypoints[..., [1, 0]]
Expand Down Expand Up @@ -300,6 +301,24 @@ def keypoints_shift_scale_rotate(
def bboxes_shift_scale_rotate(
bboxes: BoxesArray, angle: int, scale_: int, dx: int, dy: int, rotate_method: str, rows: int, cols: int, **kwargs
) -> BoxesArray:
"""Rotates, shifts and scales bounding boxes. Rotation is made by `angle` degrees,
scaling is made by the `scale_` factor and shifting is made by `dx` and `dy`.
Args:
bboxes (BoxesArray): Bounding boxes, each `(x_min, y_min, x_max, y_max)`.
angle (int): Angle of rotation in degrees.
scale_ (int): Scale factor.
dx (int): Shift along x-axis in pixel units.
dy (int): Shift along y-axis in pixel units.
rotate_method (str): Rotation method used. Should be one of: "largest_box", "ellipse".
Default: "largest_box".
rows (int): Image rows.
cols (int): Image cols.
Returns:
Bounding boxes, each `(x_min, y_min, x_max, y_max)`.
"""
if not len(bboxes):
return bboxes
center = (cols / 2, rows / 2)
Expand Down Expand Up @@ -954,10 +973,6 @@ def hflip(img: np.ndarray) -> np.ndarray:
return np.ascontiguousarray(img[:, ::-1, ...])


def vflip_cv2(img: np.ndarray) -> np.ndarray:
return cv2.flip(img, 0)


def hflip_cv2(img: np.ndarray) -> np.ndarray:
return cv2.flip(img, 1)

Expand Down Expand Up @@ -1056,7 +1071,8 @@ def bboxes_transpose(bboxes: BoxesArray, axis: int, **kwargs) -> BoxesArray:
if not len(bboxes):
return bboxes
if axis not in {0, 1}:
raise ValueError(f"Invalid axis value {axis}. Axis must be either 0 or 1")
msg = "Axis must be either 0 or 1."
raise ValueError(msg)

if axis == 0:
bboxes = bboxes[:, [1, 0, 3, 2]]
Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/geometric/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def apply_to_keypoints(
) -> KeypointsInternalType:
height = params["rows"]
width = params["cols"]

scale = max_size / min([height, width])
return F.keypoints_scale(keypoints, scale, scale)

Expand Down
6 changes: 4 additions & 2 deletions albumentations/augmentations/geometric/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class RandomRotate90(DualTransform):
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

def apply(self, img: np.ndarray, factor: float = 0, **params: Any) -> np.ndarray:
"""Args:
factor (int): number of times the input will be rotated by 90 degrees.
Expand Down Expand Up @@ -158,7 +160,7 @@ def apply_to_bboxes(
) -> BBoxesInternalType:
bboxes = F.bboxes_rotate(bboxes, angle=angle, method=self.rotate_method, rows=rows, cols=cols)
if self.crop_border:
bboxes = FCrops.bboxes_crop(
return FCrops.bboxes_crop(
bboxes,
x_min=x_min,
y_min=y_min,
Expand All @@ -183,7 +185,7 @@ def apply_to_keypoints(
):
keypoints_out = F.keypoints_rotate(keypoints, angle, rows, cols, **params)
if self.crop_border:
keypoints_out = FCrops.crop_keypoints_by_coords(keypoints_out, (x_min, y_min, x_max, y_max))
return FCrops.crop_keypoints_by_coords(keypoints_out, (x_min, y_min, x_max, y_max))
return keypoints_out

@staticmethod
Expand Down
23 changes: 16 additions & 7 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,11 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A

# top left -- no changes needed, just use jitter
# top right
points[1, 0] = 1.0 - points[1, 0] # width = 1.0 - jitter
points[1, 0] = 1.0 - points[1, 0] # w = 1.0 - jitter
# bottom right
points[2] = 1.0 - points[2] # width = 1.0 - jitter
points[2] = 1.0 - points[2] # w = 1.0 - jitter
# bottom left
points[3, 1] = 1.0 - points[3, 1] # height = 1.0 - jitter
points[3, 1] = 1.0 - points[3, 1] # h = 1.0 - jitter

points[:, 0] *= width
points[:, 1] *= height
Expand Down Expand Up @@ -738,7 +738,7 @@ def apply(
return F.warp_affine(
img,
matrix,
interpolation=self.interpolation,
interpolation=cast(int, self.interpolation),
cval=self.cval,
mode=self.mode,
output_shape=output_shape,
Expand Down Expand Up @@ -775,10 +775,16 @@ def apply_to_keypoints(
self,
keypoints: KeypointsInternalType,
matrix: Optional[skimage.transform.ProjectiveTransform] = None,
scale: Optional[dict] = None,
scale: Optional[Dict[str, Any]] = None,
**params: Any,
) -> KeypointsInternalType:
assert scale is not None and matrix is not None
if scale is None:
msg = "Expected scale to be provided, but got None."
raise ValueError(msg)
if matrix is None:
msg = "Expected matrix to be provided, but got None."
raise ValueError(msg)

return F.keypoints_affine(keypoints, matrix=matrix, scale=scale)

@property
Expand Down Expand Up @@ -1091,6 +1097,7 @@ class PadIfNeeded(DualTransform):

class PositionType(Enum):
"""Enumerates the types of positions for placing an object within a container.
This Enum class is utilized to define specific anchor positions that an object can
assume relative to a container. It's particularly useful in image processing, UI layout,
and graphic design to specify the alignment and positioning of elements.
Expand All @@ -1102,6 +1109,7 @@ class PositionType(Enum):
BOTTOM_LEFT (str): Specifies that the object should be placed at the bottom-left corner.
BOTTOM_RIGHT (str): Specifies that the object should be placed at the bottom-right corner.
RANDOM (str): Indicates that the object's position should be determined randomly.
"""

CENTER = "center"
Expand Down Expand Up @@ -1191,7 +1199,7 @@ def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]
)
return params

def apply(
def apply_to_mask(
self,
img: np.ndarray,
pad_top: int = 0,
Expand Down Expand Up @@ -1571,6 +1579,7 @@ def __init__(
p: float = 0.5,
):
super().__init__(always_apply, p)

self.num_steps = num_steps
self.distort_limit = to_tuple(distort_limit)
self.interpolation = interpolation
Expand Down

0 comments on commit 6e70334

Please sign in to comment.