Skip to content

Commit

Permalink
Fix core/bbox_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet committed Mar 14, 2024
1 parent 19bcc9f commit 86ed8c4
Showing 1 changed file with 68 additions and 31 deletions.
99 changes: 68 additions & 31 deletions albumentations/core/bbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from .transforms_interface import (
from .types import (
BBoxesInternalType,
BoxesArray,
BoxInternalType,
Expand All @@ -25,6 +25,8 @@
"use_bboxes_ndarray",
]

BBOX_WITH_LABEL_SHAPE = 5


def split_bboxes_targets(bboxes: Sequence[BoxType]) -> Tuple[np.ndarray, List[Any]]:
bbox_array, targets = [], []
Expand Down Expand Up @@ -101,6 +103,7 @@ class BboxParams(Params):
less than this value will be removed. Default: 0.0.
check_each_transform (bool): if `True`, then bboxes will be checked after each dual transform.
Default: `True`
"""

def __init__(
Expand All @@ -113,7 +116,7 @@ def __init__(
min_height: float = 0.0,
check_each_transform: bool = True,
):
super(BboxParams, self).__init__(format, label_fields)
super().__init__(format, label_fields)
self.min_area = min_area
self.min_visibility = min_visibility
self.min_width = min_width
Expand Down Expand Up @@ -164,19 +167,19 @@ def default_data_name(self) -> str:
def ensure_data_valid(self, data: Dict[str, Any]) -> None:
for data_name in self.data_fields:
data_exists = data_name in data and len(data[data_name])
if data_exists and len(data[data_name][0]) < 5:
if self.params.label_fields is None:
raise ValueError(
"Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox "
"because bboxes must have labels"
)
if self.params.label_fields:
if not all(i in data for i in self.params.label_fields):
raise ValueError("Your 'label_fields' are not valid - them must have same names as params in dict")
if data_exists and len(data[data_name][0]) < BBOX_WITH_LABEL_SHAPE and self.params.label_fields is None:
msg = (
"Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox "
"because bboxes must have labels"
)
raise ValueError(msg)
if self.params.label_fields and not all(i in data for i in self.params.label_fields):
msg = "Your 'label_fields' are not valid - them must have same names as params in dict"
raise ValueError(msg)

def filter(self, data: BoxesArray, rows: int, cols: int, target_name: str) -> BoxesArray:
self.params: BboxParams
data = filter_bboxes(
return filter_bboxes(
data,
rows,
cols,
Expand All @@ -186,8 +189,6 @@ def filter(self, data: BoxesArray, rows: int, cols: int, target_name: str) -> Bo
min_height=self.params.min_height,
)

return data

def check(self, data: BoxesArray, rows: int, cols: int) -> None:
check_bboxes(data)

Expand All @@ -209,7 +210,18 @@ def normalize_bboxes_np(bboxes: BoxesArray, rows: Union[int, float], cols: Union
Returns:
BoxesArray: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
Raises:
ValueError: If rows or cols is less or equal zero
"""
if rows <= 0:
msg = "Argument rows must be positive integer"
raise ValueError(msg)
if cols <= 0:
msg = "Argument cols must be positive integer"
raise ValueError(msg)

if not len(bboxes):
return bboxes

Expand All @@ -221,17 +233,29 @@ def normalize_bboxes_np(bboxes: BoxesArray, rows: Union[int, float], cols: Union

@use_bboxes_ndarray(return_array=True)
def denormalize_bboxes_np(bboxes: BoxesArray, rows: Union[int, float], cols: Union[int, float]) -> BoxesArray:
"""Denormalize a list of bounding boxes.
"""Denormalize coordinates of a bounding boxes. Multiply x-coordinates by image width and y-coordinates
by image height.
This is an inverse operation for :func:`~albumentations.augmentations.core.bbox_utils.normalize_bboxes_np`.
Args:
bboxes (BoxesArray): Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
bboxes: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
rows: Image height.
cols: Image width.
Returns:
BoxesArray: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
Raises:
ValueError: If rows or cols is less or equal zero
"""
if rows <= 0:
msg = "Argument rows must be positive integer"
raise ValueError(msg)
if cols <= 0:
msg = "Argument cols must be positive integer"
raise ValueError(msg)

if not len(bboxes):
return bboxes
bboxes_ = bboxes.copy().astype(float)
Expand All @@ -246,16 +270,15 @@ def calculate_bboxes_area(bboxes: BoxesArray, rows: int, cols: int) -> np.ndarra
"""Calculate the area of bounding boxes in (fractional) pixels.
Args:
bboxes (BoxesArray): A batch of bounding boxes in `albumentations` format.
rows (int): Image height
cols (int): Image width
bboxes: A batch of bounding boxes in `albumentations` format.
rows: Image height.
cols: Image width.
Returns:
numpy.ndarray: area in (fractional) pixels of the denormalized bounding boxes.
"""
bboxes_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) * cols * rows
return bboxes_area
return (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) * cols * rows


@ensure_internal_format
Expand All @@ -266,11 +289,17 @@ def convert_bboxes_to_albumentations(
"""Convert a batch of bounding boxes from a format specified in `source_format` to the format used by albumentations
Args:
bboxes (BoxesArray): A batch of bounding boxes.
source_format (str):
rows (int):
cols (int):
check_validity (bool):
bboxes: A batch of bounding boxes.
source_format: format of the bounding box. Should be 'coco', 'pascal_voc', or 'yolo'.
check_validity: Check if all boxes are valid boxes.
rows: Image height.
cols: Image width.
Note:
The `coco` format of a bounding box looks like `(x_min, y_min, width, height)`, e.g. (97, 12, 150, 200).
The `pascal_voc` format of a bounding box looks like `(x_min, y_min, x_max, y_max)`, e.g. (97, 12, 247, 212).
The `yolo` format of a bounding box looks like `(x, y, width, height)`, e.g. (0.3, 0.1, 0.05, 0.07);
where `x`, `y` coordinates of the center of the box, all values normalized to 1 by image height and width.
Returns:
BoxesArray: A batch of bounding boxes in `albumentations` format.
Expand Down Expand Up @@ -313,24 +342,32 @@ def convert_bboxes_from_albumentations(
bboxes: BoxesArray, target_format: str, rows: int, cols: int, check_validity: bool = False
) -> BoxesArray:
"""Convert a list of bounding boxes from the format used by albumentations to a format, specified
in `target_format`.
in `target_format`.
Args:
bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
bboxes: List of albumentations bounding box `(x_min, y_min, x_max, y_max)`.
target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'.
rows: Image height.
cols: Image width.
check_validity: Check if all boxes are valid boxes.
Returns:
List of bounding boxes.
np.ndarray: A bounding box.
Note:
The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
The `pascal_voc` format of a bounding box looks like `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212].
The `yolo` format of a bounding box looks like `[x, y, width, height]`, e.g. [0.3, 0.1, 0.05, 0.07].
Raises:
ValueError: if `target_format` is not equal to `coco`, `pascal_voc` or `yolo`.
"""
if not len(bboxes):
return bboxes
if target_format not in {"coco", "pascal_voc", "yolo"}:
raise ValueError(
f"Unknown target_format {target_format}. Supported formats are `coco`, `pascal_voc`, and `yolo`."
f"Unknown target_format {target_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'"
)

if check_validity:
Expand Down Expand Up @@ -392,7 +429,7 @@ def filter_bboxes(
or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size.
Args:
bboxes (BBoxesInternalType): List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
bboxes: List of albumentations bounding box `(x_min, y_min, x_max, y_max)`.
rows: Image height.
cols: Image width.
min_area: Minimum area of a bounding box. All bounding boxes whose visible area in pixels.
Expand Down

0 comments on commit 86ed8c4

Please sign in to comment.