diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index 04d5f7dc9..bb4698892 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -24,29 +24,31 @@ class AnchorGenerator(keras.layers.Layer): for larger objects. Args: - bounding_box_format (str): The format of the bounding boxes + bounding_box_format: str. The format of the bounding boxes to be generated. Expected to be a string like 'xyxy', 'xywh', etc. - min_level (int): Minimum level of the output feature pyramid. - max_level (int): Maximum level of the output feature pyramid. - num_scales (int): Number of intermediate scales added on each level. + min_level: int. Minimum level of the output feature pyramid. + max_level: int. Maximum level of the output feature pyramid. + num_scales: int. Number of intermediate scales added on each level. For example, num_scales=2 adds one additional intermediate anchor scale [2^0, 2^0.5] on each level. - aspect_ratios (list of float): Aspect ratios of anchors added on + aspect_ratios: List[float]. Aspect ratios of anchors added on each level. Each number indicates the ratio of width to height. - anchor_size (float): Scale of size of the base anchor relative to the + anchor_size: float. Scale of size of the base anchor relative to the feature stride 2^level. Call arguments: - images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]` or - `[H, W, C]`. If provided, its shape will be used to determine anchor + inputs: An image tensor with shape `[B, H, W, C]` or + `[H, W, C]`. Its shape will be used to determine anchor sizes. Returns: Dict: A dictionary mapping feature levels - (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor - of shape `(H/stride * W/stride * num_anchors_per_location, 4)`, - where H and W are the height and width of the image, stride is 2^level, - and num_anchors_per_location is `num_scales * len(aspect_ratios)`. + (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a + tensor of shape + `(H/stride * W/stride * num_anchors_per_location, 4)`, + where H and W are the height and width of the image, + stride is 2^level, and num_anchors_per_location is + `num_scales * len(aspect_ratios)`. 
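+
+        For example, with a 128x128 image, `min_level=3`, `num_scales=3`, and
+        three aspect ratios, level `'P3'` uses stride `2^3 = 8` and therefore
+        contains `(128 // 8) * (128 // 8) * 9 = 2304` anchors.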
Example: ```python @@ -81,8 +83,8 @@ def __init__( self.anchor_size = anchor_size self.built = True - def call(self, images): - images_shape = ops.shape(images) + def call(self, inputs): + images_shape = ops.shape(inputs) if len(images_shape) == 4: image_shape = images_shape[1:-1] else: @@ -147,8 +149,18 @@ def call(self, images): def compute_output_shape(self, input_shape): multilevel_boxes_shape = {} - for level in range(self.min_level, self.max_level + 1): - multilevel_boxes_shape[f"P{level}"] = (None, None, 4) + if len(input_shape) == 4: + image_height, image_width = input_shape[1:-1] + else: + image_height, image_width = input_shape[:-1] + + for i in range(self.min_level, self.max_level + 1): + multilevel_boxes_shape[f"P{i}"] = ( + (image_height // 2 ** (i)) + * (image_width // 2 ** (i)) + * self.anchors_per_location, + 4, + ) return multilevel_boxes_shape @property diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/models/retinanet/anchor_generator_test.py index 8b0669188..c843c32f2 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/models/retinanet/anchor_generator_test.py @@ -1,3 +1,4 @@ +import numpy as np from absl.testing import parameterized from keras import ops @@ -7,6 +8,32 @@ class AnchorGeneratorTest(TestCase): + def test_layer_behaviors(self): + images_shape = (8, 128, 128, 3) + self.run_layer_test( + cls=AnchorGenerator, + init_kwargs={ + "bounding_box_format": "xyxy", + "min_level": 3, + "max_level": 7, + "num_scales": 3, + "aspect_ratios": [0.5, 1.0, 2.0], + "anchor_size": 8, + }, + input_data=np.random.uniform(size=images_shape), + expected_output_shape={ + "P3": (2304, 4), + "P4": (576, 4), + "P5": (144, 4), + "P6": (36, 4), + "P7": (9, 4), + }, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + run_training_check=False, + run_precision_checks=False, + ) + @parameterized.parameters( # Single scale anchor ("yxyx", 5, 5, 1, [1.0], 2.0, [64, 64]) @@ -86,7 +113,7 @@ def test_anchor_generator( anchor_size, ) images = ops.ones(shape=(1, image_shape[0], image_shape[1], 3)) - multilevel_boxes = anchor_generator(images=images) + multilevel_boxes = anchor_generator(images) for key in expected_boxes: expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key]) expected_boxes[key] = convert_format( diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py new file mode 100644 index 000000000..5c0bbb906 --- /dev/null +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -0,0 +1,373 @@ +import keras + + +class FeaturePyramid(keras.layers.Layer): + """A Feature Pyramid Network (FPN) layer. + + This implements the paper: + Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, + and Serge Belongie. Feature Pyramid Networks for Object Detection. + (https://arxiv.org/pdf/1612.03144) + + Feature Pyramid Networks (FPNs) are basic components that are added to an + existing feature extractor (CNN) to combine features at different scales. + For the basic FPN, the inputs are features `Ci` from different levels of a + CNN, which is usually the last block for each level, where the feature is + scaled from the image by a factor of `1/2^i`. + + There is an output associated with each level in the basic FPN. 
The output
+    `Pi` at level `i` (corresponding to `Ci`) is given by performing a merge
+    operation on the outputs of:
+
+    1) a lateral operation on `Ci` (usually a conv2D layer with kernel = 1
+    and strides = 1)
+    2) a top-down upsampling operation from `Pi+1` (except for the topmost
+    level)
+
+    The final output of each level will also have a conv2D operation
+    (typically with kernel = 3 and strides = 1).
+
+    The input to the layer should be a dict whose keys match the pyramid
+    levels, e.g. for levels 3 to 5 the expected input dict is
+    `{"P3": c3, "P4": c4, "P5": c5}`.
+
+    The output of the layer will have the same structure as the input: a dict
+    with a key and value for each level. Extra, coarser levels are added on
+    top based on the `max_level` provided.
+
+    Args:
+        min_level: int. The minimum level of the feature pyramid.
+        max_level: int. The maximum level of the feature pyramid.
+        num_filters: int. The number of filters in each feature map.
+        activation: string or `keras.activations`. The activation function
+            to be used in the network.
+            Defaults to `"relu"`.
+        kernel_initializer: `str` or `keras.initializers` initializer.
+            The kernel initializer for the convolution layers.
+            Defaults to `"VarianceScaling"`.
+        bias_initializer: `str` or `keras.initializers` initializer.
+            The bias initializer for the convolution layers.
+            Defaults to `"zeros"`.
+        batch_norm_momentum: float.
+            The momentum for the batch normalization layers.
+            Defaults to `0.99`.
+        batch_norm_epsilon: float.
+            The epsilon for the batch normalization layers.
+            Defaults to `0.001`.
+        kernel_regularizer: `str` or `keras.regularizers` regularizer.
+            The kernel regularizer for the convolution layers.
+            Defaults to `None`.
+        bias_regularizer: `str` or `keras.regularizers` regularizer.
+            The bias regularizer for the convolution layers.
+            Defaults to `None`.
+        use_batch_norm: bool. Whether to use batch normalization.
+            Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+    """
+
+    def __init__(
+        self,
+        min_level,
+        max_level,
+        num_filters=256,
+        activation="relu",
+        kernel_initializer="VarianceScaling",
+        bias_initializer="zeros",
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=0.001,
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        use_batch_norm=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if min_level > max_level:
+            raise ValueError(
+                f"Minimum level ({min_level}) must be less than or equal to "
+                f"maximum level ({max_level})."
+            )
+        self.min_level = min_level
+        self.max_level = max_level
+        self.num_filters = num_filters
+        self.activation = keras.activations.get(activation)
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+        self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_epsilon = batch_norm_epsilon
+        self.use_batch_norm = use_batch_norm
+        if kernel_regularizer is not None:
+            self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+        else:
+            self.kernel_regularizer = None
+        if bias_regularizer is not None:
+            self.bias_regularizer = keras.regularizers.get(bias_regularizer)
+        else:
+            self.bias_regularizer = None
+        self.data_format = keras.backend.image_data_format()
+        self.batch_norm_axis = -1 if self.data_format == "channels_last" else 1
+
+    def build(self, input_shapes):
+        input_shapes = {
+            (
+                input_name.split("_")[0]
+                if "shape" in input_name
+                else input_name
+            ): input_shapes[input_name]
+            for input_name in input_shapes
+        }
+        input_levels = [int(level[1]) for level in input_shapes]
+        backbone_max_level = min(max(input_levels), self.max_level)
+
+        # Build lateral layers
+        self.lateral_conv_layers = {}
+        for i in range(self.min_level, backbone_max_level + 1):
+            level = f"P{i}"
+            self.lateral_conv_layers[level] = keras.layers.Conv2D(
+                filters=self.num_filters,
+                kernel_size=1,
+                padding="same",
+                data_format=self.data_format,
+                kernel_initializer=self.kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                kernel_regularizer=self.kernel_regularizer,
+                bias_regularizer=self.bias_regularizer,
+                dtype=self.dtype_policy,
+                name=f"lateral_conv_{level}",
+            )
+            self.lateral_conv_layers[level].build(input_shapes[level])
+
+        self.lateral_batch_norm_layers = {}
+        if self.use_batch_norm:
+            for i in range(self.min_level, backbone_max_level + 1):
+                level = f"P{i}"
+                self.lateral_batch_norm_layers[level] = (
+                    keras.layers.BatchNormalization(
+                        axis=self.batch_norm_axis,
+                        momentum=self.batch_norm_momentum,
+                        epsilon=self.batch_norm_epsilon,
+                        name=f"lateral_norm_{level}",
+                    )
+                )
+                self.lateral_batch_norm_layers[level].build(
+                    (None, None, None, 256)
+                    if self.data_format == "channels_last"
+                    else (None, 256, None, None)
+                )
+
+        # Build output layers
+        self.output_conv_layers = {}
+        for i in range(self.min_level, backbone_max_level + 1):
+            level = f"P{i}"
+            self.output_conv_layers[level] = keras.layers.Conv2D(
+                filters=self.num_filters,
+                kernel_size=3,
+                padding="same",
+                data_format=self.data_format,
+                kernel_initializer=self.kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                kernel_regularizer=self.kernel_regularizer,
+                bias_regularizer=self.bias_regularizer,
+                dtype=self.dtype_policy,
+                name=f"output_conv_{level}",
+            )
+            self.output_conv_layers[level].build(
+                (None, None, None, 256)
+                if self.data_format == "channels_last"
+                else (None, 256, None, None)
+            )
+
+        # Build coarser layers
+        for i in range(backbone_max_level + 1, self.max_level + 1):
+            level = f"P{i}"
+            self.output_conv_layers[level] = keras.layers.Conv2D(
+                filters=self.num_filters,
+                strides=2,
+                kernel_size=3,
+                padding="same",
+                data_format=self.data_format,
+                kernel_initializer=self.kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                kernel_regularizer=self.kernel_regularizer,
+                bias_regularizer=self.bias_regularizer,
+                dtype=self.dtype_policy,
+                name=f"coarser_{level}",
+            )
+            self.output_conv_layers[level].build(
+                (None, None, None, 256)
+                if self.data_format == "channels_last"
+                else (None, 256, None, None)
+            )
+
+        # Build batch norm layers
+        self.output_batch_norms = {}
+        if self.use_batch_norm:
+            for i in range(self.min_level, self.max_level + 1):
+                level = f"P{i}"
+                self.output_batch_norms[level] = (
+                    keras.layers.BatchNormalization(
+                        axis=self.batch_norm_axis,
+                        momentum=self.batch_norm_momentum,
+                        epsilon=self.batch_norm_epsilon,
+                        name=f"output_norm_{level}",
+                    )
+                )
+                self.output_batch_norms[level].build(
+                    (None, None, None, 256)
+                    if self.data_format == "channels_last"
+                    else (None, 256, None, None)
+                )
+
+        # The same upsampling layer is used for all levels
+        self.top_down_op = keras.layers.UpSampling2D(
+            size=2,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="upsampling",
+        )
+        # The same merge layer is used for all levels
+        self.merge_op = keras.layers.Add(
+            dtype=self.dtype_policy, name="merge_op"
+        )
+
+        self.built = True
+
+    def call(self, inputs):
+        """
+        Inputs:
+            The input to the layer is expected to be a dict of tensors,
+            mapping level names (e.g. `"P3"`) to the feature maps on top of
+            which the FPN will be added.
+
+        Outputs:
+            A dictionary of feature maps and added coarser levels based
+            on minimum and maximum levels provided to the layer.
+        """
+
+        output_features = {}
+
+        # Get the backbone max level
+        input_levels = [int(level[1]) for level in inputs]
+        backbone_max_level = min(max(input_levels), self.max_level)
+
+        for i in range(backbone_max_level, self.min_level - 1, -1):
+            level = f"P{i}"
+            output = self.lateral_conv_layers[level](inputs[level])
+            # The topmost level has no coarser level above it to merge with,
+            # so only the lower levels merge in the upsampled upstream output.
+            if i < backbone_max_level:
+                upstream_output = self.top_down_op(output_features[f"P{i+1}"])
+                output = self.merge_op([output, upstream_output])
+            output_features[level] = (
+                self.lateral_batch_norm_layers[level](output)
+                if self.use_batch_norm
+                else output
+            )
+
+        # Apply the output conv layers in a second pass so that their results
+        # do not leak into the top-down pathway of the levels below.
+        for i in range(backbone_max_level, self.min_level - 1, -1):
+            level = f"P{i}"
+            output_features[level] = self.output_conv_layers[level](
+                output_features[level]
+            )
+
+        for i in range(backbone_max_level + 1, self.max_level + 1):
+            level = f"P{i}"
+            feats_in = output_features[f"P{i-1}"]
+            if i > backbone_max_level + 1:
+                feats_in = self.activation(feats_in)
+            output_features[level] = (
+                self.output_batch_norms[level](
+                    self.output_conv_layers[level](feats_in)
+                )
+                if self.use_batch_norm
+                else self.output_conv_layers[level](feats_in)
+            )
+
+        return output_features
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "min_level": self.min_level,
+                "max_level": self.max_level,
+                "num_filters": self.num_filters,
+                "use_batch_norm": self.use_batch_norm,
+                "activation": keras.activations.serialize(self.activation),
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "batch_norm_momentum": self.batch_norm_momentum,
+                "batch_norm_epsilon": self.batch_norm_epsilon,
+                "kernel_regularizer": (
+                    keras.regularizers.serialize(self.kernel_regularizer)
+                    if self.kernel_regularizer is not None
+                    else None
+                ),
+                "bias_regularizer": (
+                    keras.regularizers.serialize(self.bias_regularizer)
+                    if self.bias_regularizer is not None
+                    else None
+                ),
+            }
+        )
+
+        return config
+
+    def compute_output_shape(self, input_shapes):
+        output_shape = {}
+        input_levels = [int(level[1]) for level in input_shapes]
+        backbone_max_level = 
min(max(input_levels), self.max_level) + + for i in range(self.min_level, backbone_max_level + 1): + level = f"P{i}" + if self.data_format == "channels_last": + output_shape[level] = input_shapes[level][:-1] + (256,) + else: + output_shape[level] = ( + input_shapes[level][0], + 256, + ) + input_shapes[level][1:3] + + intermediate_shape = input_shapes[f"P{backbone_max_level}"] + intermediate_shape = ( + ( + intermediate_shape[0], + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + 256, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + 256, + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + ) + ) + + for i in range(backbone_max_level + 1, self.max_level + 1): + level = f"P{i}" + output_shape[level] = intermediate_shape + intermediate_shape = ( + ( + intermediate_shape[0], + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + 256, + ) + if self.data_format == "channels_last" + else ( + intermediate_shape[0], + 256, + intermediate_shape[1] // 2, + intermediate_shape[2] // 2, + ) + ) + + return output_shape diff --git a/keras_hub/src/models/retinanet/feature_pyramid_test.py b/keras_hub/src/models/retinanet/feature_pyramid_test.py new file mode 100644 index 000000000..728233c6a --- /dev/null +++ b/keras_hub/src/models/retinanet/feature_pyramid_test.py @@ -0,0 +1,81 @@ +from absl.testing import parameterized +from keras import ops +from keras import random + +from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid +from keras_hub.src.tests.test_case import TestCase + + +class FeaturePyramidTest(TestCase): + def test_layer_behaviors(self): + self.run_layer_test( + cls=FeaturePyramid, + init_kwargs={ + "min_level": 3, + "max_level": 7, + "activation": "relu", + "batch_norm_momentum": 0.99, + "batch_norm_epsilon": 0.0001, + "kernel_initializer": "HeNormal", + "bias_initializer": "Zeros", + }, + input_data={ + "P3": random.uniform(shape=(2, 64, 64, 4)), + "P4": random.uniform(shape=(2, 32, 32, 8)), + "P5": random.uniform(shape=(2, 16, 16, 16)), + }, + expected_output_shape={ + "P3": (2, 64, 64, 256), + "P4": (2, 32, 32, 256), + "P5": (2, 16, 16, 256), + "P6": (2, 8, 8, 256), + "P7": (2, 4, 4, 256), + }, + expected_num_trainable_weights=16, + expected_num_non_trainable_weights=0, + ) + + @parameterized.named_parameters( + ( + "equal_resolutions", + 3, + 7, + {"P3": (2, 16, 16, 3), "P4": (2, 8, 8, 3), "P5": (2, 4, 4, 3)}, + ), + ( + "different_resolutions", + 2, + 6, + { + "P2": (2, 64, 128, 4), + "P3": (2, 32, 64, 8), + "P4": (2, 16, 32, 16), + "P5": (2, 8, 16, 32), + }, + ), + ) + def test_layer_output_shapes(self, min_level, max_level, input_shapes): + layer = FeaturePyramid(min_level=min_level, max_level=max_level) + + inputs = { + level: ops.ones(input_shapes[level]) for level in input_shapes + } + if layer.data_format == "channels_first": + inputs = { + level: ops.transpose(inputs[level], (0, 3, 1, 2)) + for level in inputs + } + + output = layer(inputs) + + for level in inputs: + self.assertEqual( + output[level].shape, + ( + (input_shapes[level][0],) + + (layer.num_filters,) + + input_shapes[level][1:3] + if layer.data_format == "channels_first" + else input_shapes[level][:-1] + (layer.num_filters,) + ), + ) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py new file mode 100644 index 000000000..a5bf475b2 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -0,0 +1,270 @@ +import keras +from keras import ops + 
+from keras_hub.src.bounding_box.converters import _encode_box_to_deltas
+from keras_hub.src.bounding_box.iou import compute_iou
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.box_matcher import BoxMatcher
+from keras_hub.src.utils import tensor_utils
+
+
+class RetinaNetLabelEncoder(keras.layers.Layer):
+    """Transforms the raw labels into targets for training.
+
+    RetinaNet is a single-stage object detection network that uses a feature
+    pyramid network and focal loss. This class is crucial for preparing the
+    ground truth data to match the network's anchor-based detection approach.
+
+    This class generates targets for a batch of samples which consists of
+    input images, bounding boxes for the objects present, and their class
+    ids. It matches ground truth boxes to anchor boxes based on IoU
+    (Intersection over Union) and encodes the box coordinates as offsets
+    from the anchors.
+
+    Targets are always represented in 'center_yxwh' format for numerical
+    consistency during training, regardless of the input format.
+
+    Args:
+        bounding_box_format: str. The format of the bounding boxes in the
+            input dataset, e.g. `"xyxy"`. TODO: Add link to Keras Core docs.
+        min_level: int. Minimum level of the output feature pyramid.
+        max_level: int. Maximum level of the output feature pyramid.
+        num_scales: int. Number of intermediate scales added on each level.
+            For example, num_scales=2 adds one additional intermediate anchor
+            scale [2^0, 2^0.5] on each level.
+        aspect_ratios: List[float]. Aspect ratios of anchors added on
+            each level. Each number indicates the ratio of width to height.
+        anchor_size: float. Scale of size of the base anchor relative to the
+            feature stride 2^level.
+        positive_threshold: float. The IoU threshold above which an anchor is
+            considered a positive match to a ground truth box.
+            Defaults to `0.5`.
+        negative_threshold: float. The IoU threshold below which an anchor is
+            considered a negative (background) match.
+            Defaults to `0.4`.
+        box_variance: List[float]. The scaling factors used to scale the
+            bounding box targets.
+            Defaults to `[0.1, 0.1, 0.2, 0.2]`.
+        background_class: float. The class ID used for the background class.
+            Defaults to `-1.0`.
+        ignore_class: float. The class ID used for the ignore class.
+            Defaults to `-2.0`.
+        box_matcher_match_values: List[int]. Values representing the match
+            results (e.g. positive, negative or ignored match).
+            `len(match_values)` must equal `len(thresholds) + 1`.
+            Defaults to `[-1, -2, 1]`.
+        box_matcher_force_match_for_each_col: bool. If `True`, each column
+            (ground truth box) will be matched to at least one row (anchor
+            box); otherwise some ground truth boxes may not be matched to
+            any anchor.
+            Defaults to `False`.
+
+    Note: `tf.RaggedTensor` inputs are not supported.
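+
+    Example:
+
+    An illustrative sketch; the shapes below mirror the unit tests, and the
+    anchor count (3069) is specific to a 128x128 input with these settings.
+    ```python
+    import numpy as np
+
+    encoder = RetinaNetLabelEncoder(
+        bounding_box_format="xyxy",
+        min_level=3,
+        max_level=7,
+        num_scales=3,
+        aspect_ratios=[0.5, 1.0, 2.0],
+        anchor_size=8,
+    )
+    images = np.random.uniform(size=(8, 128, 128, 3))
+    gt_boxes = np.random.uniform(size=(8, 10, 4), low=0.0, high=1.0)
+    gt_classes = np.random.uniform(size=(8, 10), low=0, high=5)
+    box_targets, class_targets = encoder(images, gt_boxes, gt_classes)
+    # box_targets: (8, 3069, 4), class_targets: (8, 3069)
+    ```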
+    """
+
+    def __init__(
+        self,
+        bounding_box_format,
+        min_level,
+        max_level,
+        num_scales,
+        aspect_ratios,
+        anchor_size,
+        positive_threshold=0.5,
+        negative_threshold=0.4,
+        box_variance=[0.1, 0.1, 0.2, 0.2],
+        background_class=-1.0,
+        ignore_class=-2.0,
+        box_matcher_match_values=[-1, -2, 1],
+        box_matcher_force_match_for_each_col=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.bounding_box_format = bounding_box_format
+        self.min_level = min_level
+        self.max_level = max_level
+        self.num_scales = num_scales
+        self.aspect_ratios = aspect_ratios
+        self.anchor_size = anchor_size
+        self.positive_threshold = positive_threshold
+        self.box_variance = box_variance
+        self.negative_threshold = negative_threshold
+        self.background_class = background_class
+        self.ignore_class = ignore_class
+
+        self.anchor_generator = AnchorGenerator(
+            bounding_box_format=bounding_box_format,
+            min_level=min_level,
+            max_level=max_level,
+            num_scales=num_scales,
+            aspect_ratios=aspect_ratios,
+            anchor_size=anchor_size,
+        )
+
+        self.box_matcher = BoxMatcher(
+            thresholds=[negative_threshold, positive_threshold],
+            match_values=box_matcher_match_values,
+            force_match_for_each_col=box_matcher_force_match_for_each_col,
+        )
+
+    def build(self, images_shape, gt_boxes_shape, gt_classes_shape):
+        self.built = True
+
+    def call(self, images, gt_boxes, gt_classes):
+        """Creates box and classification targets for a batch.
+
+        Args:
+            images: A Tensor. The input images argument should be
+                of shape `[B, H, W, C]` or `[B, C, H, W]`.
+            gt_boxes: A Tensor of shape `[B, num_boxes, 4]`.
+            gt_classes: A Tensor of shape `[B, num_boxes]` or
+                `[B, num_boxes, 1]` containing the class ids.
+
+        Returns:
+            box_targets: A Tensor of shape `[batch_size, num_anchors, 4]`
+                containing the encoded box targets.
+            class_targets: A Tensor of shape `[batch_size, num_anchors]`
+                containing the class targets for each anchor.
+        """
+
+        images_shape = ops.shape(images)
+        if len(images_shape) != 4:
+            raise ValueError(
+                "`RetinaNetLabelEncoder`'s `call()` method does not "
+                "support unbatched inputs for the `images` argument. "
+                f"Received `shape(images)={images_shape}`."
+            )
+        image_shape = images_shape[1:]
+
+        if len(ops.shape(gt_classes)) == 2:
+            gt_classes = ops.expand_dims(gt_classes, axis=-1)
+
+        anchor_boxes = self.anchor_generator(images)
+        anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0)
+
+        box_targets, class_targets = self._encode_sample(
+            gt_boxes, gt_classes, anchor_boxes, image_shape
+        )
+        box_targets = ops.reshape(
+            box_targets, (-1, ops.shape(box_targets)[1], 4)
+        )
+        return box_targets, class_targets
+
+    def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape):
+        """Creates box and classification targets for a batched sample.
+
+        Matches ground truth boxes to anchor boxes based on IOU.
+        1. Calculates the pairwise IOU for the M `anchor_boxes` and N
+            `gt_boxes` to get a `(M, N)` shaped matrix.
+        2. The ground truth box with the maximum IOU in each row is assigned
+            to the anchor box provided the IOU is greater than
+            `positive_threshold`.
+        3. If the maximum IOU in a row is less than `negative_threshold`, the
+            anchor box is assigned the background class.
+        4. The remaining anchor boxes that do not have any class assigned are
+            ignored during training.
+
+        Args:
+            gt_boxes: A Tensor of shape `[B, num_boxes, 4]`. Should be in
+                `bounding_box_format`.
+            gt_classes: A Tensor of shape `[B, num_boxes, 1]` containing the
+                class ids.
+ anchor_boxes: A Tensor with the shape `[total_anchors, 4]` + representing all the anchor boxes for a given input image shape, + where each anchor box is of the format `[x, y, width, height]`. + image_shape: Tuple indicating the image shape `[H, W, C]`. + + Returns: + Encoded boudning boxes in the format of `center_yxwh` and + corresponding labels for each encoded bounding box. + """ + + iou_matrix = compute_iou( + anchor_boxes, + gt_boxes, + bounding_box_format=self.bounding_box_format, + image_shape=image_shape, + ) + + matched_gt_idx, matched_vals = self.box_matcher(iou_matrix) + matched_vals = ops.expand_dims(matched_vals, axis=-1) + positive_mask = ops.cast(ops.equal(matched_vals, 1), self.dtype) + ignore_mask = ops.cast(ops.equal(matched_vals, -2), self.dtype) + + matched_gt_boxes = tensor_utils.target_gather(gt_boxes, matched_gt_idx) + + matched_gt_boxes = ops.reshape( + matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) + ) + + box_target = _encode_box_to_deltas( + anchors=anchor_boxes, + boxes=matched_gt_boxes, + anchor_format=self.bounding_box_format, + box_format=self.bounding_box_format, + variance=self.box_variance, + image_shape=image_shape, + ) + + matched_gt_cls_ids = tensor_utils.target_gather( + gt_classes, matched_gt_idx + ) + cls_target = ops.where( + ops.not_equal(positive_mask, 1.0), + self.background_class, + matched_gt_cls_ids, + ) + cls_target = ops.where( + ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target + ) + label = ops.concatenate( + [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1 + ) + + # In the case that a box in the corner of an image matches with an all + # -1 box that is outside the image, we should assign the box to the + # ignore class. There are rare cases where a -1 box can be matched, + # resulting in a NaN during training. The unit test passing all -1s to + # the label encoder ensures that we properly handle this edge-case. 
+ label = ops.where( + ops.expand_dims(ops.any(ops.isnan(label), axis=-1), axis=-1), + self.ignore_class, + label, + ) + + return label[:, :, :4], label[:, :, 4] + + def get_config(self): + config = super().get_config() + config.update( + { + "bounding_box_format": self.bounding_box_format, + "min_level": self.min_level, + "max_level": self.max_level, + "num_scales": self.num_scales, + "aspect_ratios": self.aspect_ratios, + "anchor_size": self.anchor_size, + "positive_threshold": self.positive_threshold, + "box_variance": self.box_variance, + "negative_threshold": self.negative_threshold, + "background_class": self.background_class, + "ignore_class": self.ignore_class, + } + ) + return config + + def compute_output_shape( + self, images_shape, gt_boxes_shape, gt_classes_shape + ): + min_level = self.anchor_generator.min_level + max_level = self.anchor_generator.max_level + batch_size, image_H, image_W = images_shape[:-1] + + total_num_anchors = 0 + for i in range(min_level, max_level + 1): + total_num_anchors += ( + (image_H // 2 ** (i)) + * (image_W // 2 ** (i)) + * self.anchor_generator.anchors_per_location + ) + + return (batch_size, total_num_anchors, 4), ( + batch_size, + total_num_anchors, + ) diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py new file mode 100644 index 000000000..de329685a --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -0,0 +1,85 @@ +import numpy as np +from keras import ops + +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetLabelEncoderTest(TestCase): + def test_layer_behaviors(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + self.run_layer_test( + cls=RetinaNetLabelEncoder, + init_kwargs={ + "bounding_box_format": "xyxy", + "min_level": 3, + "max_level": 7, + "num_scales": 3, + "aspect_ratios": [0.5, 1.0, 2.0], + "anchor_size": 8, + }, + input_data={ + "images": np.random.uniform(size=images_shape), + "gt_boxes": np.random.uniform( + size=boxes_shape, low=0.0, high=1.0 + ), + "gt_classes": np.random.uniform( + size=classes_shape, low=0, high=5 + ), + }, + expected_output_shape=((8, 3069, 4), (8, 3069)), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + run_training_check=False, + run_precision_checks=False, + ) + + def test_label_encoder_output_shapes(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + + images = np.random.uniform(size=images_shape) + boxes = np.random.uniform(size=boxes_shape, low=0.0, high=1.0) + classes = np.random.uniform(size=classes_shape, low=0, high=5) + + encoder = RetinaNetLabelEncoder( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + + box_targets, class_targets = encoder(images, boxes, classes) + + self.assertEqual(box_targets.shape, (8, 3069, 4)) + self.assertEqual(class_targets.shape, (8, 3069)) + + def test_all_negative_1(self): + images_shape = (8, 128, 128, 3) + boxes_shape = (8, 10, 4) + classes_shape = (8, 10) + + images = np.random.uniform(size=images_shape) + boxes = -np.ones(shape=boxes_shape, dtype="float32") + classes = -np.ones(shape=classes_shape, dtype="float32") + + encoder = RetinaNetLabelEncoder( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + 
aspect_ratios=[0.5, 1.0, 2.0],
+            anchor_size=8,
+        )
+
+        box_targets, class_targets = encoder(images, boxes, classes)
+
+        self.assertFalse(ops.any(ops.isnan(box_targets)))
+        self.assertFalse(ops.any(ops.isnan(class_targets)))
diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py
index 310d8e8b4..6d06c7266 100644
--- a/keras_hub/src/tests/test_case.py
+++ b/keras_hub/src/tests/test_case.py
@@ -13,6 +13,7 @@
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
+from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
 from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.keras_utils import has_quantization_support
 from keras_hub.src.utils.tensor_utils import is_float_dtype
@@ -127,7 +128,10 @@ def __init__(self, layer):
 
     def call(self, x):
         if isinstance(x, dict):
-            return self.layer(**x)
+            if isinstance(self.layer, FeaturePyramid):
+                return self.layer(x)
+            else:
+                return self.layer(**x)
         else:
             return self.layer(x)
 
@@ -147,7 +151,10 @@ def call(self, x):
         layer = cls(**init_kwargs)
         if isinstance(input_data, dict):
             shapes = {k + "_shape": v.shape for k, v in input_data.items()}
-            layer.build(**shapes)
+            if isinstance(layer, FeaturePyramid):
+                layer.build(shapes)
+            else:
+                layer.build(**shapes)
         else:
             layer.build(input_data.shape)
         run_build_asserts(layer)
@@ -158,7 +165,10 @@ def call(self, x):
         )
         layer = cls(**init_kwargs)
         if isinstance(keras_tensor_inputs, dict):
-            keras_tensor_outputs = layer(**keras_tensor_inputs)
+            if isinstance(layer, FeaturePyramid):
+                keras_tensor_outputs = layer(keras_tensor_inputs)
+            else:
+                keras_tensor_outputs = layer(**keras_tensor_inputs)
         else:
             keras_tensor_outputs = layer(keras_tensor_inputs)
         run_build_asserts(layer)
@@ -167,7 +177,10 @@ def call(self, x):
         # Eager call test and compiled training test.
         layer = cls(**init_kwargs)
         if isinstance(input_data, dict):
-            output_data = layer(**input_data)
+            if isinstance(layer, FeaturePyramid):
+                output_data = layer(input_data)
+            else:
+                output_data = layer(**input_data)
         else:
             output_data = layer(input_data)
         run_output_asserts(layer, output_data, eager=True)
@@ -305,8 +318,12 @@ def run_precision_test(self, cls, init_kwargs, input_data):
             output_data = layer(input_data)
             output_spec = layer.compute_output_spec(input_data)
         elif isinstance(input_data, dict):
-            output_data = layer(**input_data)
-            output_spec = layer.compute_output_spec(**input_data)
+            if isinstance(layer, FeaturePyramid):
+                output_data = layer(input_data)
+                output_spec = layer.compute_output_spec(input_data)
+            else:
+                output_data = layer(**input_data)
+                output_spec = layer.compute_output_spec(**input_data)
         else:
             output_data = layer(input_data)
             output_spec = layer.compute_output_spec(input_data)
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index 3d18aae99..36177c894 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -308,3 +308,109 @@ def any_equal(inputs, values, padding_mask):
         output = ops.logical_or(output, value_equality)
 
     return ops.logical_and(output, padding_mask)
+
+
+def target_gather(
+    targets,
+    indices,
+    mask=None,
+    mask_val=0.0,
+):
+    """A utility function wrapping `ops.take`, which deals with:
+        1) both batched and unbatched `targets`.
+        2) when unbatched `targets` have empty rows, the result will be filled
+            with `mask_val`.
+        3) target masking.
+
+    Args:
+        targets: `[N, ...]` or `[batch_size, N, ...]` Tensor representing
+            targets such as boxes, keypoints, etc.
+ indices: [M] or [batch_size, M] int32 Tensor representing indices within + `targets` to gather. + mask: `[M, ...]` or `[batch_size, M, ...]` boolean Tensor + representing the masking for each target. `True` means the + corresponding entity should be masked to `mask_val`, `False` + means the corresponding entity should be the target value. + Defaults to `None`. + mask_val: float. representing the masking value if `mask` is True + on the entity. + Defaults to `0.0` + + Returns: + targets: `[M, ...]` or `[batch_size, M, ...]` Tensor representing + selected targets. + + Raise: + ValueError: If `targets` is higher than rank 3. + """ + targets_shape = list(targets.shape) + if len(targets_shape) > 3: + raise ValueError( + f"`target_gather` does not support `targets` with rank " + f"larger than 3, got {len(targets.shape)}" + ) + + def gather_unbatched(labels, match_indices, mask, mask_val): + """Gather based on unbatched labels and boxes.""" + num_gt_boxes = labels.shape[0] + + def assign_when_rows_empty(): + if len(labels.shape) > 1: + mask_shape = [match_indices.shape[0], labels.shape[-1]] + else: + mask_shape = [match_indices.shape[0]] + return ops.cast(mask_val, labels.dtype) * ops.ones( + mask_shape, dtype=labels.dtype + ) + + def assign_when_rows_not_empty(): + targets = ops.take(labels, match_indices, axis=0) + if mask is None: + return targets + else: + masked_targets = ops.cast( + mask_val, labels.dtype + ) * ops.ones_like(mask, dtype=labels.dtype) + return ops.where(mask, masked_targets, targets) + + if num_gt_boxes > 0: + return assign_when_rows_not_empty() + else: + return assign_when_rows_empty() + + def _gather_batched(labels, match_indices, mask, mask_val): + """Gather based on batched labels.""" + batch_size = labels.shape[0] + if batch_size == 1: + if mask is not None: + result = gather_unbatched( + ops.squeeze(labels, axis=0), + ops.squeeze(match_indices, axis=0), + ops.squeeze(mask, axis=0), + mask_val, + ) + else: + result = gather_unbatched( + ops.squeeze(labels, axis=0), + ops.squeeze(match_indices, axis=0), + None, + mask_val, + ) + return ops.expand_dims(result, axis=0) + else: + targets = ops.take_along_axis( + labels, ops.expand_dims(match_indices, axis=-1), axis=1 + ) + + if mask is None: + return targets + else: + masked_targets = ops.cast( + mask_val, labels.dtype + ) * ops.ones_like(mask, dtype=labels.dtype) + return ops.where(mask, masked_targets, targets) + + if len(targets_shape) <= 2: + return gather_unbatched(targets, indices, mask, mask_val) + elif len(targets_shape) == 3: + return _gather_batched(targets, indices, mask, mask_val) diff --git a/keras_hub/src/utils/tensor_utils_test.py b/keras_hub/src/utils/tensor_utils_test.py index 42d04a029..0b6ef1f34 100644 --- a/keras_hub/src/utils/tensor_utils_test.py +++ b/keras_hub/src/utils/tensor_utils_test.py @@ -10,6 +10,7 @@ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch from keras_hub.src.utils.tensor_utils import is_tensor_type from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.tensor_utils import target_gather from keras_hub.src.utils.tensor_utils import tensor_to_list @@ -202,3 +203,104 @@ def test_input_shaped_values(self): result = any_equal(inputs, values, padding_mask) result = ops.convert_to_numpy(result) self.assertAllEqual(result, expected_output) + + +class TargetGatherTest(TestCase): + def test_target_gather_boxes_batched(self): + target_boxes = np.array( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]] + ) + target_boxes = 
ops.expand_dims(target_boxes, axis=0) + indices = np.array([[0, 2]], dtype="int32") + expected_boxes = np.array([[0, 0, 5, 5], [5, 0, 10, 5]]) + expected_boxes = ops.expand_dims(expected_boxes, axis=0) + res = target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_boxes_unbatched(self): + target_boxes = np.array( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]], + "int32", + ) + indices = np.array([0, 2], dtype="int32") + expected_boxes = np.array([[0, 0, 5, 5], [5, 0, 10, 5]]) + res = target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_classes_batched(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + expected_classes = np.array([[1, 3]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched(self): + target_classes = np.array([1, 2, 3, 4]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + expected_classes = np.array([1, 3]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + masks = np.array(([[False, True]])) + masks = ops.expand_dims(masks, axis=-1) + # the second element is masked + expected_classes = np.array([[1, 0]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask_val(self): + target_classes = np.array([[1, 2, 3, 4]]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2]], dtype="int32") + masks = np.array(([[False, True]])) + masks = ops.expand_dims(masks, axis=-1) + # the second element is masked + expected_classes = np.array([[1, -1]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks, -1) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched_with_mask(self): + target_classes = np.array([1, 2, 3, 4]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + masks = np.array([False, True]) + masks = ops.expand_dims(masks, axis=-1) + expected_classes = np.array([1, 0]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_with_empty_targets(self): + target_classes = np.array([]) + target_classes = ops.expand_dims(target_classes, axis=-1) + indices = np.array([0, 2], dtype="int32") + # return all 0s since input is empty + expected_classes = np.array([0, 0]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_multi_batch(self): + target_classes = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + target_classes = 
ops.expand_dims(target_classes, axis=-1) + indices = np.array([[0, 2], [1, 3]], dtype="int32") + expected_classes = np.array([[1, 3], [6, 8]]) + expected_classes = ops.expand_dims(expected_classes, axis=-1) + res = target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_invalid_rank(self): + targets = np.random.normal(size=[32, 2, 2, 2]) + indices = np.array([0, 1], dtype="int32") + with self.assertRaisesRegex(ValueError, "larger than 3"): + _ = target_gather(targets, indices)
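
A minimal sketch of how the new `FeaturePyramid` and `AnchorGenerator` layers
compose end to end, assuming the `channels_last` image data format; the
backbone feature shapes and the 512x512 image below are illustrative values
taken from the tests, not requirements:

```python
import numpy as np

from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid

# Backbone-like features whose resolution halves at each level; channel
# counts are arbitrary, since the FPN's 1x1 lateral convs project every
# level to `num_filters` (256 by default).
features = {
    "P3": np.random.uniform(size=(2, 64, 64, 64)).astype("float32"),
    "P4": np.random.uniform(size=(2, 32, 32, 128)).astype("float32"),
    "P5": np.random.uniform(size=(2, 16, 16, 256)).astype("float32"),
}
fpn = FeaturePyramid(min_level=3, max_level=7)
pyramid = fpn(features)  # keys "P3"..."P7", each with 256 channels

# Anchors are generated from the image shape (512x512 here); one tensor of
# shape (H/2^i * W/2^i * num_scales * len(aspect_ratios), 4) per level.
anchor_generator = AnchorGenerator(
    bounding_box_format="xyxy",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=8,
)
images = np.ones((2, 512, 512, 3), dtype="float32")
anchors = anchor_generator(images)  # dict of per-level anchor boxes
```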