Skip to content

Commit

Permalink
Add SAM preprocessor and checkpoint conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
divyashreepathihalli committed Sep 28, 2024
1 parent 8b34476 commit a61a595
Show file tree
Hide file tree
Showing 7 changed files with 498 additions and 241 deletions.
86 changes: 86 additions & 0 deletions keras_hub/src/models/image_segmenter_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.models.ImageSegmenterPreprocessor")
class ImageSegmenterPreprocessor(Preprocessor):
    """Base class for image segmentation preprocessing layers.

    `ImageSegmenterPreprocessor` wraps a
    `keras_hub.layers.ImageConverter` to create a preprocessing layer for
    image segmentation tasks. It is intended to be paired with a
    `keras_hub.models.ImageSegmenter` task.

    All `ImageSegmenterPreprocessor` instances take three inputs: `x`, `y`, and
    `sample_weight`.

    - `x`: The first input, should always be included. It can be an image or
      a batch of images.
    - `y`: (Optional) Usually the segmentation mask(s), will be passed through
      unaltered.
    - `sample_weight`: (Optional) Will be passed through unaltered.

    The layer will output either `x`, an `(x, y)` tuple if labels were
    provided, or an `(x, y, sample_weight)` tuple if labels and sample weight
    were provided. `x` will be the input images after all model preprocessing
    has been applied.

    All `ImageSegmenterPreprocessor` tasks include a `from_preset()`
    constructor which can be used to load a pre-trained config and
    vocabularies. You can call the `from_preset()` constructor directly on
    this base class, in which case the correct class for your model will be
    automatically instantiated.

    Examples.
    ```python
    preprocessor = keras_hub.models.ImageSegmenterPreprocessor.from_preset(
        "deeplabv3_resnet50",
    )

    # Resize a single image for the model.
    x = np.ones((512, 512, 3))
    x = preprocessor(x)

    # Resize an image and its mask.
    x, y = np.ones((512, 512, 3)), np.zeros((512, 512, 1))
    x, y = preprocessor(x, y)

    # Resize a batch of images and masks.
    x, y = (
        [np.ones((512, 512, 3)), np.zeros((512, 512, 3))],
        [np.ones((512, 512, 1)), np.zeros((512, 512, 1))],
    )
    x, y = preprocessor(x, y)

    # Use a `tf.data.Dataset`.
    ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def __init__(
        self,
        image_converter=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Optional `keras_hub.layers.ImageConverter` applied to `x` in
        # `call()`. When `None`, images pass through unchanged.
        self.image_converter = image_converter

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None):
        # Only `x` (the images) is converted; `y` (masks) and
        # `sample_weight` are passed through unaltered.
        if self.image_converter:
            x = self.image_converter(x)
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
23 changes: 23 additions & 0 deletions keras_hub/src/models/sam/sam_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.resizing_image_converter import (
ResizingImageConverter,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone


@keras_hub_export("keras_hub.layers.SamImageConverter")
class SamImageConverter(ResizingImageConverter):
    """Image converter for the Segment Anything Model (SAM).

    A `ResizingImageConverter` subclass registered against `SAMBackbone`;
    it adds no behavior of its own beyond associating the converter with
    SAM model presets.
    """

    # Associates this converter class with the SAM backbone for preset
    # resolution.
    backbone_cls = SAMBackbone
6 changes: 5 additions & 1 deletion keras_hub/src/models/sam/sam_image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SamImageSegmenterPreprocessor,
)


@keras_hub_export("keras_hub.models.SAMImageSegmenter")
Expand Down Expand Up @@ -179,7 +182,7 @@ class SAMImageSegmenter(ImageSegmenter):
"""

backbone_cls = SAMBackbone
preprocessor_cls = None
preprocessor_cls = SamImageSegmenterPreprocessor

def __init__(self, backbone, preprocessor=None, **kwargs):
# The implementation has been adapted form [Segment Anything
Expand All @@ -188,6 +191,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
# [Detectron2](https://github.com/facebookresearch/detectron2).
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor
# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
Expand Down
25 changes: 25 additions & 0 deletions keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_converter import SamImageConverter


@keras_hub_export("keras_hub.models.SamImageSegmenterPreprocessor")
class SamImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
    """Image segmentation preprocessor for the Segment Anything Model (SAM).

    An `ImageSegmenterPreprocessor` subclass that wires the SAM-specific
    backbone and image converter classes together; it adds no behavior of
    its own.
    """

    # Associates this preprocessor with the SAM backbone and its image
    # converter for preset resolution.
    backbone_cls = SAMBackbone
    image_converter_cls = SamImageConverter
47 changes: 47 additions & 0 deletions keras_hub/src/models/sam/sam_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAM preset configurations."""

# Preset registry mapping preset names to metadata and Kaggle weight handles.
# NOTE: the original commit had the `params` counts for the large and huge
# presets swapped (huge was listed as smaller than large). SAM with a ViT-L
# image encoder totals ~312M parameters and SAM with ViT-H totals ~641M, so
# the counts below have been corrected accordingly.
backbone_presets = {
    "sam_base_sa1b": {
        "metadata": {
            "description": ("The base SAM model trained on the SA1B dataset."),
            "params": 93735728,
            "official_name": "SAMImageSegmenter",
            "path": "sam",
            "model_card": "https://arxiv.org/abs/2304.02643",
        },
        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_base_sa1b/1",
    },
    "sam_large_sa1b": {
        "metadata": {
            "description": ("The large SAM model trained on the SA1B dataset."),
            "params": 312343088,
            "official_name": "SAMImageSegmenter",
            "path": "sam",
            "model_card": "https://arxiv.org/abs/2304.02643",
        },
        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_large_sa1b/1",
    },
    "sam_huge_sa1b": {
        "metadata": {
            "description": ("The huge SAM model trained on the SA1B dataset."),
            "params": 641090864,
            "official_name": "SAMImageSegmenter",
            "path": "sam",
            "model_card": "https://arxiv.org/abs/2304.02643",
        },
        "kaggle_handle": "kaggle://kerashub/sam/keras/sam_huge_sa1b/1",
    },
}
4 changes: 2 additions & 2 deletions keras_hub/src/models/vit_det/vit_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(
"Input size must be provided if using relative "
"positional encoding."
)
self.add_decomposed_reative_pe = AddRelativePositionalEmbedding(
self.add_decomposed_relative_pe = AddRelativePositionalEmbedding(
self.input_size, self.key_dim
)

Expand Down Expand Up @@ -256,7 +256,7 @@ def call(self, x):
keys, axes=(0, 2, 1)
)
if self.use_rel_pos:
attention_map = self.add_decomposed_reative_pe(
attention_map = self.add_decomposed_relative_pe(
attention_map,
queries=queries,
query_size=(height, width),
Expand Down
Loading

0 comments on commit a61a595

Please sign in to comment.