diff --git a/keras_nlp/api/bounding_box/__init__.py b/keras_nlp/api/bounding_box/__init__.py
index 18be1cd9a..8488f76e6 100644
--- a/keras_nlp/api/bounding_box/__init__.py
+++ b/keras_nlp/api/bounding_box/__init__.py
@@ -18,6 +18,19 @@
 """
 from keras_nlp.src.bounding_box.converters import convert_format
+from keras_nlp.src.bounding_box.formats import CENTER_XYWH
+from keras_nlp.src.bounding_box.formats import REL_XYWH
+from keras_nlp.src.bounding_box.formats import REL_XYXY
+from keras_nlp.src.bounding_box.formats import REL_YXYX
+from keras_nlp.src.bounding_box.formats import XYWH
+from keras_nlp.src.bounding_box.formats import XYXY
+from keras_nlp.src.bounding_box.formats import YXYX
+from keras_nlp.src.bounding_box.iou import compute_ciou
+from keras_nlp.src.bounding_box.iou import compute_iou
 from keras_nlp.src.bounding_box.to_dense import to_dense
 from keras_nlp.src.bounding_box.to_ragged import to_ragged
+from keras_nlp.src.bounding_box.utils import as_relative
+from keras_nlp.src.bounding_box.utils import clip_boxes
+from keras_nlp.src.bounding_box.utils import clip_to_image
+from keras_nlp.src.bounding_box.utils import is_relative
 from keras_nlp.src.bounding_box.validate_format import validate_format
diff --git a/keras_nlp/src/bounding_box/formats.py b/keras_nlp/src/bounding_box/formats.py
new file mode 100644
index 000000000..fda64a860
--- /dev/null
+++ b/keras_nlp/src/bounding_box/formats.py
@@ -0,0 +1,162 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+formats.py contains axis information for each supported format.
+"""
+
+from keras_nlp.src.api_export import keras_nlp_export
+
+
+@keras_nlp_export("keras_nlp.bounding_box.XYXY")
+class XYXY:
+    """XYXY contains axis indices for the XYXY format.
+
+    All values in the XYXY format should be absolute pixel values.
+
+    The XYXY format consists of the following required indices:
+
+    - LEFT: left of the bounding box
+    - TOP: top of the bounding box
+    - RIGHT: right of the bounding box
+    - BOTTOM: bottom of the bounding box
+    """
+
+    LEFT = 0
+    TOP = 1
+    RIGHT = 2
+    BOTTOM = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.REL_XYXY")
+class REL_XYXY:
+    """REL_XYXY contains axis indices for the REL_XYXY format.
+
+    REL_XYXY is like XYXY, but each value is relative to the width and height
+    of the origin image. Values are percentages of the origin image's width
+    and height, respectively.
+
+    The REL_XYXY format consists of the following required indices:
+
+    - LEFT: left of the bounding box
+    - TOP: top of the bounding box
+    - RIGHT: right of the bounding box
+    - BOTTOM: bottom of the bounding box
+    """
+
+    LEFT = 0
+    TOP = 1
+    RIGHT = 2
+    BOTTOM = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.CENTER_XYWH")
+class CENTER_XYWH:
+    """CENTER_XYWH contains axis indices for the CENTER_XYWH format.
+
+    All values in the CENTER_XYWH format should be absolute pixel values.
+
+    The CENTER_XYWH format consists of the following required indices:
+
+    - X: X coordinate of the center of the bounding box
+    - Y: Y coordinate of the center of the bounding box
+    - WIDTH: width of the bounding box
+    - HEIGHT: height of the bounding box
+    """
+
+    X = 0
+    Y = 1
+    WIDTH = 2
+    HEIGHT = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.XYWH")
+class XYWH:
+    """XYWH contains axis indices for the XYWH format.
+
+    All values in the XYWH format should be absolute pixel values.
+
+    The XYWH format consists of the following required indices:
+
+    - X: X coordinate of the left of the bounding box
+    - Y: Y coordinate of the top of the bounding box
+    - WIDTH: width of the bounding box
+    - HEIGHT: height of the bounding box
+    """
+
+    X = 0
+    Y = 1
+    WIDTH = 2
+    HEIGHT = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.REL_XYWH")
+class REL_XYWH:
+    """REL_XYWH contains axis indices for the REL_XYWH format.
+
+    REL_XYWH is like XYWH, but each value is relative to the width and height
+    of the origin image. Values are percentages of the origin image's width
+    and height, respectively.
+
+    The REL_XYWH format consists of the following required indices:
+
+    - X: X coordinate of the left of the bounding box
+    - Y: Y coordinate of the top of the bounding box
+    - WIDTH: width of the bounding box
+    - HEIGHT: height of the bounding box
+    """
+
+    X = 0
+    Y = 1
+    WIDTH = 2
+    HEIGHT = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.YXYX")
+class YXYX:
+    """YXYX contains axis indices for the YXYX format.
+
+    All values in the YXYX format should be absolute pixel values.
+
+    The YXYX format consists of the following required indices:
+
+    - TOP: top of the bounding box
+    - LEFT: left of the bounding box
+    - BOTTOM: bottom of the bounding box
+    - RIGHT: right of the bounding box
+    """
+
+    TOP = 0
+    LEFT = 1
+    BOTTOM = 2
+    RIGHT = 3
+
+
+@keras_nlp_export("keras_nlp.bounding_box.REL_YXYX")
+class REL_YXYX:
+    """REL_YXYX contains axis indices for the REL_YXYX format.
+
+    REL_YXYX is like YXYX, but each value is relative to the width and height
+    of the origin image. Values are percentages of the origin image's width
+    and height, respectively.
+
+    The REL_YXYX format consists of the following required indices:
+
+    - TOP: top of the bounding box
+    - LEFT: left of the bounding box
+    - BOTTOM: bottom of the bounding box
+    - RIGHT: right of the bounding box
+    """
+
+    TOP = 0
+    LEFT = 1
+    BOTTOM = 2
+    RIGHT = 3
diff --git a/keras_nlp/src/bounding_box/iou.py b/keras_nlp/src/bounding_box/iou.py
new file mode 100644
index 000000000..46ea2a34b
--- /dev/null
+++ b/keras_nlp/src/bounding_box/iou.py
@@ -0,0 +1,263 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains functions to compute ious of bounding boxes.""" +import math + +import keras +from keras import ops + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.bounding_box.converters import convert_format +from keras_nlp.src.bounding_box.utils import as_relative +from keras_nlp.src.bounding_box.utils import is_relative + + +def _compute_area(box): + """Computes area for bounding boxes + + Args: + box: [N, 4] or [batch_size, N, 4] float Tensor, either batched + or unbatched boxes. + Returns: + a float Tensor of [N] or [batch_size, N] + """ + y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1) + return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1) + + +def _compute_intersection(boxes1, boxes2): + """Computes intersection area between two sets of boxes. + + Args: + boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes. + boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes. + Returns: + a [N, M] or [batch_size, N, M] float Tensor. + """ + y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + boxes2_rank = len(boxes2.shape) + perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1] + # [N, M] or [batch_size, N, M] + intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm)) + intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm)) + intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm)) + intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm)) + + intersect_height = intersect_ymax - intersect_ymin + intersect_width = intersect_xmax - intersect_xmin + zeros_t = ops.cast(0, intersect_height.dtype) + intersect_height = ops.maximum(zeros_t, intersect_height) + intersect_width = ops.maximum(zeros_t, intersect_width) + + return intersect_height * intersect_width + + +@keras_nlp_export("keras_nlp.bounding_box.compute_iou") +def compute_iou( + boxes1, + boxes2, + bounding_box_format, + use_masking=False, + mask_val=-1, + images=None, + image_shape=None, +): + """Computes a lookup table vector containing the ious for a given set boxes. + + The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if + boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the + boxes are batched. + + The users can pass `boxes1` and `boxes2` to be different ranks. For example: + 1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N]. + 2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return + [batch_size, M, N] + 3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N] + 4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N] + + Args: + boxes1: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + boxes2: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + bounding_box_format: a case-insensitive string which is one of `"xyxy"`, + `"rel_xyxy"`, `"xyWH"`, `"center_xyWH"`, `"yxyx"`, `"rel_yxyx"`. + For detailed information on the supported format, see the + [KerasCV bounding box documentation](https://keras.io/api/keras_cv/bounding_box/formats/). + use_masking: whether masking will be applied. This will mask all `boxes1` + or `boxes2` that have values less than 0 in all its 4 dimensions. + Default to `False`. + mask_val: int to mask those returned IOUs if the masking is True, defaults + to -1. + + Returns: + iou_lookup_table: a vector containing the pairwise ious of boxes1 and + boxes2. 
+ """ # noqa: E501 + + boxes1_rank = len(boxes1.shape) + boxes2_rank = len(boxes2.shape) + + if boxes1_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes1 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes1.shape)=2 AND or len(boxes1.shape)=3." + ) + if boxes2_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes2 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes2.shape)=2 AND or len(boxes2.shape)=3." + ) + + target_format = "yxyx" + if is_relative(bounding_box_format): + target_format = as_relative(target_format) + + boxes1 = convert_format( + boxes1, + source=bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + + boxes2 = convert_format( + boxes2, + source=bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + + intersect_area = _compute_intersection(boxes1, boxes2) + boxes1_area = _compute_area(boxes1) + boxes2_area = _compute_area(boxes2) + boxes2_area_rank = len(boxes2_area.shape) + boxes2_axis = 1 if (boxes2_area_rank == 2) else 0 + boxes1_area = ops.expand_dims(boxes1_area, axis=-1) + boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis) + union_area = boxes1_area + boxes2_area - intersect_area + res = ops.divide(intersect_area, union_area + keras.backend.epsilon()) + + if boxes1_rank == 2: + perm = [1, 0] + else: + perm = [0, 2, 1] + + if not use_masking: + return res + + mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res) + boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0) + boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0) + background_mask = ops.logical_or( + boxes1_mask, ops.transpose(boxes2_mask, perm) + ) + iou_lookup_table = ops.where(background_mask, mask_val_t, res) + return iou_lookup_table + + +@keras_nlp_export("keras_nlp.bounding_box.compute_ciou") +def compute_ciou(boxes1, boxes2, bounding_box_format): + """ + Computes the Complete IoU (CIoU) between two bounding boxes or between + two batches of bounding boxes. + + CIoU loss is an extension of GIoU loss, which further improves the IoU + optimization for object detection. CIoU loss not only penalizes the + bounding box coordinates but also considers the aspect ratio and center + distance of the boxes. The length of the last dimension should be 4 to + represent the bounding boxes. + + Args: + box1 (tensor): tensor representing the first bounding box with + shape (..., 4). + box2 (tensor): tensor representing the second bounding box with + shape (..., 4). + bounding_box_format: a case-insensitive string (for example, "xyxy"). + Each bounding box is defined by these 4 values. For detailed + information on the supported formats, see the [KerasCV bounding box + documentation](https://keras.io/api/keras_cv/bounding_box/formats/). + + Returns: + tensor: The CIoU distance between the two bounding boxes. 
+ """ + target_format = "xyxy" + if is_relative(bounding_box_format): + target_format = as_relative(target_format) + + boxes1 = convert_format( + boxes1, source=bounding_box_format, target=target_format + ) + + boxes2 = convert_format( + boxes2, source=bounding_box_format, target=target_format + ) + + x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + + width_1 = x_max1 - x_min1 + height_1 = y_max1 - y_min1 + keras.backend.epsilon() + width_2 = x_max2 - x_min2 + height_2 = y_max2 - y_min2 + keras.backend.epsilon() + + intersection_area = ops.maximum( + ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0 + ) * ops.maximum( + ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0 + ) + union_area = ( + width_1 * height_1 + + width_2 * height_2 + - intersection_area + + keras.backend.epsilon() + ) + iou = ops.squeeze( + ops.divide(intersection_area, union_area + keras.backend.epsilon()), + axis=-1, + ) + + convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2) + convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2) + convex_diagonal_squared = ops.squeeze( + convex_width**2 + convex_height**2 + keras.backend.epsilon(), + axis=-1, + ) + centers_distance_squared = ops.squeeze( + ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2 + + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2, + axis=-1, + ) + + v = ops.squeeze( + ops.power( + (4 / math.pi**2) + * (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)), + 2, + ), + axis=-1, + ) + alpha = v / (v - iou + (1 + keras.backend.epsilon())) + + return iou - ( + centers_distance_squared / convex_diagonal_squared + v * alpha + ) diff --git a/keras_nlp/src/bounding_box/iou_test.py b/keras_nlp/src/bounding_box/iou_test.py new file mode 100644 index 000000000..ffd3b61cf --- /dev/null +++ b/keras_nlp/src/bounding_box/iou_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for iou functions.""" + +import numpy as np + +from keras_nlp.src.bounding_box import iou as iou_lib +from keras_nlp.src.tests.test_case import TestCase + + +class IoUTest(TestCase): + def test_compute_single_iou(self): + bb1 = np.array([[100, 101, 200, 201]]) + bb1_off_by_1 = np.array([[101, 102, 201, 202]]) + # area of bb1 and bb1_off_by_1 are each 10000. 
+        # intersection area is 99*99=9801
+        # iou=9801/(2*10000 - 9801)=0.96097656633
+        self.assertAllClose(
+            iou_lib.compute_iou(bb1, bb1_off_by_1, "yxyx")[0], [0.96097656633]
+        )
+
+    def test_compute_iou(self):
+        bb1 = [100, 101, 200, 201]
+        bb1_off_by_1_pred = [101, 102, 201, 202]
+        iou_bb1_bb1_off = 0.96097656633
+        top_left_bounding_box = [0, 2, 1, 3]
+        far_away_box = [1300, 1400, 1500, 1401]
+        another_far_away_pred = [1000, 1400, 1200, 1401]
+
+        # Rows represent ground truths, columns predictions
+        expected_result = np.array(
+            [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+            dtype=np.float32,
+        )
+
+        sample_y_true = np.array([bb1, top_left_bounding_box, far_away_box])
+        sample_y_pred = np.array(
+            [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred],
+        )
+
+        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx")
+        self.assertAllClose(expected_result, result)
+
+    def test_batched_compute_iou(self):
+        bb1 = [100, 101, 200, 201]
+        bb1_off_by_1_pred = [101, 102, 201, 202]
+        iou_bb1_bb1_off = 0.96097656633
+        top_left_bounding_box = [0, 2, 1, 3]
+        far_away_box = [1300, 1400, 1500, 1401]
+        another_far_away_pred = [1000, 1400, 1200, 1401]
+
+        # Rows represent ground truths, columns predictions
+        expected_result = np.array(
+            [
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+            ],
+        )
+
+        sample_y_true = np.array(
+            [
+                [bb1, top_left_bounding_box, far_away_box],
+                [bb1, top_left_bounding_box, far_away_box],
+            ],
+        )
+        sample_y_pred = np.array(
+            [
+                [
+                    bb1_off_by_1_pred,
+                    top_left_bounding_box,
+                    another_far_away_pred,
+                ],
+                [
+                    bb1_off_by_1_pred,
+                    top_left_bounding_box,
+                    another_far_away_pred,
+                ],
+            ],
+        )
+
+        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx")
+        self.assertAllClose(expected_result, result)
+
+    def test_batched_boxes1_unbatched_boxes2(self):
+        bb1 = [100, 101, 200, 201]
+        bb1_off_by_1_pred = [101, 102, 201, 202]
+        iou_bb1_bb1_off = 0.96097656633
+        top_left_bounding_box = [0, 2, 1, 3]
+        far_away_box = [1300, 1400, 1500, 1401]
+        another_far_away_pred = [1000, 1400, 1200, 1401]
+
+        # Rows represent ground truths, columns predictions
+        expected_result = np.array(
+            [
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+            ],
+        )
+
+        sample_y_true = np.array(
+            [
+                [bb1, top_left_bounding_box, far_away_box],
+                [bb1, top_left_bounding_box, far_away_box],
+            ],
+        )
+        sample_y_pred = np.array(
+            [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred],
+        )
+
+        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx")
+        self.assertAllClose(expected_result, result)
+
+    def test_unbatched_boxes1_batched_boxes2(self):
+        bb1 = [100, 101, 200, 201]
+        bb1_off_by_1_pred = [101, 102, 201, 202]
+        iou_bb1_bb1_off = 0.96097656633
+        top_left_bounding_box = [0, 2, 1, 3]
+        far_away_box = [1300, 1400, 1500, 1401]
+        another_far_away_pred = [1000, 1400, 1200, 1401]
+
+        # Rows represent ground truths, columns predictions
+        expected_result = np.array(
+            [
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+            ],
+        )
+
+        sample_y_true = np.array(
+            [
+                [bb1, top_left_bounding_box, far_away_box],
+            ],
+        )
+        sample_y_pred = np.array(
+            [
+                [
+                    bb1_off_by_1_pred,
+                    top_left_bounding_box,
+                    another_far_away_pred,
+                ],
+                [
+                    bb1_off_by_1_pred,
+                    top_left_bounding_box,
+                    another_far_away_pred,
+                ],
+            ],
+        )
+
+        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx")
+        self.assertAllClose(expected_result, result)
diff --git a/keras_nlp/src/bounding_box/utils.py b/keras_nlp/src/bounding_box/utils.py
new file mode 100644
index 000000000..a96c284a6
--- /dev/null
+++ b/keras_nlp/src/bounding_box/utils.py
@@ -0,0 +1,194 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions for working with bounding boxes."""
+
+from keras import ops
+
+from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.bounding_box import converters
+from keras_nlp.src.bounding_box.formats import XYWH
+
+
+@keras_nlp_export("keras_nlp.bounding_box.is_relative")
+def is_relative(bounding_box_format):
+    """A util to check if a bounding box format uses relative coordinates."""
+    if bounding_box_format.lower() not in converters.TO_XYXY_CONVERTERS:
+        raise ValueError(
+            "`is_relative()` received an unsupported format for the argument "
+            f"`bounding_box_format`. `bounding_box_format` should be one of "
+            f"{converters.TO_XYXY_CONVERTERS.keys()}. "
+            f"Got bounding_box_format={bounding_box_format}"
+        )
+
+    return bounding_box_format.startswith("rel")
+
+
+@keras_nlp_export("keras_nlp.bounding_box.as_relative")
+def as_relative(bounding_box_format):
+    """A util to get the relative equivalent of a provided bounding box format.
+
+    If the specified format is already a relative format, it will be returned
+    unchanged.
+    """
+
+    if not is_relative(bounding_box_format):
+        return "rel_" + bounding_box_format
+
+    return bounding_box_format
+
+
+def _relative_area(boxes, bounding_box_format):
+    boxes = converters.convert_format(
+        boxes,
+        source=bounding_box_format,
+        target="rel_xywh",
+    )
+    widths = boxes[..., XYWH.WIDTH]
+    heights = boxes[..., XYWH.HEIGHT]
+    # handle corner case where shear performs a full inversion.
+    return ops.where(
+        ops.logical_and(widths > 0, heights > 0), widths * heights, 0.0
+    )
+
+
+@keras_nlp_export("keras_nlp.bounding_box.clip_to_image")
+def clip_to_image(
+    bounding_boxes, bounding_box_format, images=None, image_shape=None
+):
+    """Clips bounding boxes to image boundaries.
+
+    `clip_to_image()` clips bounding boxes that have coordinates out of bounds
+    of an image down to the boundaries of the image. This is done by
+    converting the bounding boxes to a relative format, then clipping them to
+    the `[0, 1]` range. Additionally, bounding boxes that end up with a zero
+    area have their class ID set to -1, indicating that there is no object
+    present in them.
+
+    Args:
+        bounding_boxes: bounding box tensor to clip.
+        bounding_box_format: the KerasCV bounding box format the bounding
+            boxes are in.
+        images: list of images to clip the bounding boxes to.
+        image_shape: the shape of the images to clip the bounding boxes to.
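+
+    Example:
+
+    A minimal usage sketch (the values are illustrative, mirroring the unit
+    tests in this change):
+
+    ```python
+    import numpy as np
+    from keras import ops
+
+    import keras_nlp
+
+    bounding_boxes = {
+        "boxes": np.array([[200.0, 200.0, 400.0, 400.0]]),
+        "classes": np.array([0]),
+    }
+    image = ops.ones((256, 256, 3))
+    clipped = keras_nlp.bounding_box.clip_to_image(
+        bounding_boxes, bounding_box_format="xyxy", images=image
+    )
+    # clipped["boxes"] is [[200.0, 200.0, 256.0, 256.0]]; the class stays 0
+    # because the clipped box still has positive area.
+    ```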
+ """ + boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"] + + boxes = converters.convert_format( + boxes, + source=bounding_box_format, + target="rel_xyxy", + images=images, + image_shape=image_shape, + ) + boxes, classes, images, squeeze = _format_inputs(boxes, classes, images) + x1, y1, x2, y2 = ops.split(boxes, 4, axis=-1) + clipped_bounding_boxes = ops.concatenate( + [ + ops.clip(x1, 0, 1), + ops.clip(y1, 0, 1), + ops.clip(x2, 0, 1), + ops.clip(y2, 0, 1), + ], + axis=-1, + ) + areas = _relative_area( + clipped_bounding_boxes, bounding_box_format="rel_xyxy" + ) + clipped_bounding_boxes = converters.convert_format( + clipped_bounding_boxes, + source="rel_xyxy", + target=bounding_box_format, + images=images, + image_shape=image_shape, + ) + clipped_bounding_boxes = ops.where( + ops.expand_dims(areas > 0.0, axis=-1), clipped_bounding_boxes, -1.0 + ) + classes = ops.where(areas > 0.0, classes, -1) + nan_indices = ops.any(ops.isnan(clipped_bounding_boxes), axis=-1) + classes = ops.where(nan_indices, -1, classes) + + # TODO update dict and return + clipped_bounding_boxes, classes = _format_outputs( + clipped_bounding_boxes, classes, squeeze + ) + + bounding_boxes.update({"boxes": clipped_bounding_boxes, "classes": classes}) + + return bounding_boxes + + +@keras_nlp_export("keras_nlp.bounding_box.clip_boxes") +def clip_boxes(boxes, image_shape): + """Clip boxes to the boundaries of the image shape""" + if boxes.shape[-1] != 4: + raise ValueError( + "boxes.shape[-1] is {:d}, but must be 4.".format(boxes.shape[-1]) + ) + + if isinstance(image_shape, list) or isinstance(image_shape, tuple): + height, width, _ = image_shape + max_length = ops.stack([height, width, height, width], axis=-1) + else: + image_shape = ops.cast(image_shape, dtype=boxes.dtype) + height = image_shape[0] + width = image_shape[1] + max_length = ops.stack([height, width, height, width], axis=-1) + + clipped_boxes = ops.maximum(ops.minimum(boxes, max_length), 0.0) + return clipped_boxes + + +def _format_inputs(boxes, classes, images): + boxes_rank = len(boxes.shape) + if boxes_rank > 3: + raise ValueError( + "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got " + f"len(boxes.shape)={boxes_rank}" + ) + boxes_includes_batch = boxes_rank == 3 + # Determine if images needs an expand_dims() call + if images is not None: + images_rank = len(images.shape) + if images_rank > 4: + raise ValueError( + "Expected len(images.shape)=2, or len(images.shape)=3, got " + f"len(images.shape)={images_rank}" + ) + images_include_batch = images_rank == 4 + if boxes_includes_batch != images_include_batch: + raise ValueError( + "clip_to_image() expects both boxes and images to be batched, " + "or both boxes and images to be unbatched. Received " + f"len(boxes.shape)={boxes_rank}, " + f"len(images.shape)={images_rank}. Expected either " + "len(boxes.shape)=2 AND len(images.shape)=3, or " + "len(boxes.shape)=3 AND len(images.shape)=4." 
+ ) + if not images_include_batch: + images = ops.expand_dims(images, axis=0) + + if not boxes_includes_batch: + return ( + ops.expand_dims(boxes, axis=0), + ops.expand_dims(classes, axis=0), + images, + True, + ) + return boxes, classes, images, False + + +def _format_outputs(boxes, classes, squeeze): + if squeeze: + return ops.squeeze(boxes, axis=0), ops.squeeze(classes, axis=0) + return boxes, classes diff --git a/keras_nlp/src/bounding_box/utils_test.py b/keras_nlp/src/bounding_box/utils_test.py new file mode 100644 index 000000000..cf6143684 --- /dev/null +++ b/keras_nlp/src/bounding_box/utils_test.py @@ -0,0 +1,166 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from keras import ops + +from keras_nlp.src.bounding_box import utils +from keras_nlp.src.tests.test_case import TestCase + + +class BoundingBoxUtilTest(TestCase): + def test_clip_to_image_standard(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[200, 200, 400, 400], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + boxes = bounding_boxes["boxes"] + self.assertAllGreaterEqual(ops.convert_to_numpy(boxes), 0) + ( + x1, + y1, + x2, + y2, + ) = ops.split(boxes, 4, axis=1) + self.assertAllLessEqual( + ops.convert_to_numpy(ops.concatenate([x1, x2], axis=1)), width + ) + self.assertAllLessEqual( + ops.convert_to_numpy(ops.concatenate([y1, y2], axis=1)), height + ) + # Test relative format batched + image = ops.ones(shape=(1, height, width, 3)) + + bounding_boxes = { + "boxes": np.array([[[0.2, -1, 1.2, 0.3], [0.4, 1.5, 0.2, 0.3]]]), + "classes": np.array([[0, 0]]), + } + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="rel_xyxy", images=image + ) + boxes = bounding_boxes["boxes"] + self.assertAllLessEqual(ops.convert_to_numpy(boxes), 1) + + def test_clip_to_image_filters_fully_out_bounding_boxes(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[257, 257, 400, 400], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + + self.assertAllEqual( + bounding_boxes["boxes"], + np.array([[-1, -1, -1, -1], [100, 100, 256, 256]]), + ), + self.assertAllEqual( + bounding_boxes["classes"], + np.array([-1, 0]), + ) + + def test_clip_to_image_filters_fully_out_bounding_boxes_negative_area(self): + # Test xyxy format unbatched + height = 256 + width = 256 + bounding_boxes = { + "boxes": np.array([[110, 120, 100, 100], [100, 100, 300, 300]]), + "classes": np.array([0, 0]), + } + image = ops.ones(shape=(height, width, 3)) + bounding_boxes = utils.clip_to_image( + bounding_boxes, bounding_box_format="xyxy", images=image + ) + 
+        self.assertAllEqual(
+            bounding_boxes["boxes"],
+            np.array([[-1, -1, -1, -1], [100, 100, 256, 256]]),
+        )
+        self.assertAllEqual(
+            bounding_boxes["classes"],
+            np.array([-1, 0]),
+        )
+
+    def test_clip_to_image_filters_nans(self):
+        # Test xyxy format unbatched
+        height = 256
+        width = 256
+        bounding_boxes = {
+            "boxes": np.array(
+                [[0, float("NaN"), 100, 100], [100, 100, 300, 300]]
+            ),
+            "classes": np.array([0, 0]),
+        }
+        image = ops.ones(shape=(height, width, 3))
+        bounding_boxes = utils.clip_to_image(
+            bounding_boxes, bounding_box_format="xyxy", images=image
+        )
+        self.assertAllEqual(
+            bounding_boxes["boxes"],
+            np.array([[-1, -1, -1, -1], [100, 100, 256, 256]]),
+        )
+        self.assertAllEqual(
+            bounding_boxes["classes"],
+            np.array([-1, 0]),
+        )
+
+    def test_is_relative_util(self):
+        self.assertTrue(utils.is_relative("rel_xyxy"))
+        self.assertFalse(utils.is_relative("xyxy"))
+
+        with self.assertRaises(ValueError):
+            _ = utils.is_relative("bad_format")
+
+    def test_as_relative_util(self):
+        self.assertEqual(utils.as_relative("yxyx"), "rel_yxyx")
+        self.assertEqual(utils.as_relative("rel_xywh"), "rel_xywh")
diff --git a/keras_nlp/src/bounding_box/validate_format_test.py b/keras_nlp/src/bounding_box/validate_format_test.py
new file mode 100644
index 000000000..020279f33
--- /dev/null
+++ b/keras_nlp/src/bounding_box/validate_format_test.py
@@ -0,0 +1,47 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from keras_nlp.src.bounding_box import validate_format
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class ValidateTest(TestCase):
+    def test_raises_nondict(self):
+        with self.assertRaisesRegex(
+            ValueError, "Expected `bounding_boxes` to be a dictionary, got "
+        ):
+            validate_format.validate_format(np.ones((4, 3, 6)))
+
+    def test_mismatch_dimensions(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Expected `boxes` and `classes` to have matching dimensions",
+        ):
+            validate_format.validate_format(
+                {"boxes": np.ones((4, 3, 6)), "classes": np.ones((4, 6))}
+            )
+
+    def test_bad_keys(self):
+        with self.assertRaisesRegex(ValueError, "containing keys"):
+            validate_format.validate_format(
+                {"box": [1, 2, 3], "class": [1234]}
+            )