diff --git a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java index 0fd4f76c70b..13a5b17189e 100644 --- a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java @@ -295,7 +295,7 @@ public List findBoundingBoxes() { /** {@inheritDoc} */ @Override - public void drawBoundingBoxes(DetectedObjects detections) { + public void drawBoundingBoxes(DetectedObjects detections, float opacity) { // Make image copy with alpha channel because original image was jpg convertIdNeeded(); @@ -321,18 +321,20 @@ public void drawBoundingBoxes(DetectedObjects detections) { k = (k + 100) % 255; } - Rectangle rectangle = box.getBounds(); - int x = (int) (rectangle.getX() * imageWidth); - int y = (int) (rectangle.getY() * imageHeight); - g.drawRect( - x, - y, - (int) (rectangle.getWidth() * imageWidth), - (int) (rectangle.getHeight() * imageHeight)); - drawText(g, className, x, y, stroke, 4); + if (!className.isEmpty()) { + Rectangle rectangle = box.getBounds(); + int x = (int) (rectangle.getX() * imageWidth); + int y = (int) (rectangle.getY() * imageHeight); + g.drawRect( + x, + y, + (int) (rectangle.getWidth() * imageWidth), + (int) (rectangle.getHeight() * imageHeight)); + drawText(g, className, x, y, stroke, 4); + } // If we have a mask instead of a plain rectangle, draw tha mask if (box instanceof Mask) { - drawMask((Mask) box); + drawMask((Mask) box, opacity); } else if (box instanceof Landmark) { drawLandmarks(box); } @@ -340,6 +342,19 @@ public void drawBoundingBoxes(DetectedObjects detections) { g.dispose(); } + /** {@inheritDoc} */ + @Override + public void drawMarks(List points, int radius) { + Graphics2D g = (Graphics2D) image.getGraphics(); + g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); + g.setColor(new Color(246, 96, 0)); + for (Point point : points) { + int[][] star = createStar(point, radius); + g.fillPolygon(star[0], star[1], 10); + } + g.dispose(); + } + /** {@inheritDoc} */ @Override public void drawJoints(Joints joints) { @@ -421,7 +436,7 @@ private void drawText(Graphics2D g, String text, int x, int y, int stroke, int p g.drawString(text, x + padding, y + ascent); } - private void drawMask(Mask mask) { + private void drawMask(Mask mask, float ratio) { float r = RandomUtils.nextFloat(); float g = RandomUtils.nextFloat(); float b = RandomUtils.nextFloat(); @@ -445,13 +460,15 @@ private void drawMask(Mask mask) { } } float[][] probDist = mask.getProbDist(); - float max = 0; - for (float[] row : probDist) { - for (float f : row) { - max = Math.max(max, f); + if (ratio < 0 || ratio > 1) { + float max = 0; + for (float[] row : probDist) { + for (float f : row) { + max = Math.max(max, f); + } } + ratio = 0.5f / max; } - float ratio = 0.5f / max; BufferedImage maskImage = new BufferedImage( diff --git a/api/src/main/java/ai/djl/modality/cv/Image.java b/api/src/main/java/ai/djl/modality/cv/Image.java index 26e7fe4ce6f..89e78eb2863 100644 --- a/api/src/main/java/ai/djl/modality/cv/Image.java +++ b/api/src/main/java/ai/djl/modality/cv/Image.java @@ -15,6 +15,7 @@ import ai.djl.modality.cv.output.BoundingBox; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Joints; +import ai.djl.modality.cv.output.Point; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; @@ -125,7 +126,36 @@ default NDArray toNDArray(NDManager manager) { * * @param detections the object detection results */ - void drawBoundingBoxes(DetectedObjects detections); + default void drawBoundingBoxes(DetectedObjects detections) { + drawBoundingBoxes(detections, -1); + } + + /** + * Draws the bounding boxes on the image. + * + * @param detections the object detection results + */ + void drawBoundingBoxes(DetectedObjects detections, float opacity); + + /** + * Draws a mark on the image. + * + * @param points as list of {@code Point} + */ + default void drawMarks(List points) { + int w = getWidth(); + int h = getHeight(); + int size = Math.min(w, h) / 50; + drawMarks(points, size); + } + + /** + * Draws a mark on the image. + * + * @param points as list of {@code Point} + * @param size the radius of the star mark + */ + void drawMarks(List points, int size); /** * Draws all joints of a body on an image. @@ -142,6 +172,32 @@ default NDArray toNDArray(NDManager manager) { */ void drawImage(Image overlay, boolean resize); + /** + * Creates a star shape. + * + * @param point the coordinate + * @param radius the radius + * @return the polygon points + */ + default int[][] createStar(Point point, int radius) { + int[][] ret = new int[2][10]; + double midX = point.getX(); + double midY = point.getY(); + double[] ratio = {radius, radius * 0.38196601125}; + + double delta = Math.PI / 5; + for (int i = 0; i < 10; ++i) { + double angle = i * delta; + double r = ratio[i % 2]; + double x = Math.cos(angle) * r; + double y = Math.sin(angle) * r; + + ret[0][i] = (int) (x + midX); + ret[1][i] = (int) (y + midY); + } + return ret; + } + /** Flag indicates the color channel options for images. */ enum Flag { GRAYSCALE, diff --git a/api/src/main/java/ai/djl/modality/cv/output/Mask.java b/api/src/main/java/ai/djl/modality/cv/output/Mask.java index e9ada0710fc..622b807d1b0 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Mask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Mask.java @@ -12,6 +12,9 @@ */ package ai.djl.modality.cv.output; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.types.Shape; + /** * A mask with a probability for each pixel within a bounding rectangle. * @@ -75,4 +78,22 @@ public float[][] getProbDist() { public boolean isFullImageMask() { return fullImageMask; } + + /** + * Converts the mask tensor to a mask array. + * + * @param array the mask NDArray + * @return the mask array + */ + public static float[][] toMask(NDArray array) { + Shape maskShape = array.getShape(); + int height = (int) maskShape.get(0); + int width = (int) maskShape.get(1); + float[] flattened = array.toFloatArray(); + float[][] mask = new float[height][width]; + for (int i = 0; i < height; i++) { + System.arraycopy(flattened, i * width, mask[i], 0, width); + } + return mask; + } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java index 458122bbc0b..02856c13f08 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java @@ -95,14 +95,7 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { // Reshape mask to actual image bounding box shape. NDArray array = masks.get(i); - Shape maskShape = array.getShape(); - int maskH = (int) maskShape.get(0); - int maskW = (int) maskShape.get(1); - float[] flattened = array.toFloatArray(); - float[][] maskFloat = new float[maskH][maskW]; - for (int j = 0; j < maskH; j++) { - System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW); - } + float[][] maskFloat = Mask.toMask(array); Mask mask = new Mask(x, y, w, h, maskFloat); retNames.add(className); diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java new file mode 100644 index 00000000000..4299af02ba3 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -0,0 +1,188 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.output.BoundingBox; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Mask; +import ai.djl.modality.cv.output.Point; +import ai.djl.modality.cv.transform.Normalize; +import ai.djl.modality.cv.transform.Resize; +import ai.djl.modality.cv.transform.ToTensor; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.Pipeline; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; + +/** A {@link Translator} that handles mask generation task. */ +public class Sam2Translator implements NoBatchifyTranslator { + + private static final float[] MEAN = {0.485f, 0.456f, 0.406f}; + private static final float[] STD = {0.229f, 0.224f, 0.225f}; + + private Pipeline pipeline; + + /** Constructs a {@code Sam2Translator} instance. */ + public Sam2Translator() { + pipeline = new Pipeline(); + pipeline.add(new Resize(1024, 1024)); + pipeline.add(new ToTensor()); + pipeline.add(new Normalize(MEAN, STD)); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Exception { + Image image = input.getImage(); + int width = image.getWidth(); + int height = image.getHeight(); + ctx.setAttachment("width", width); + ctx.setAttachment("height", height); + + List points = input.getPoints(); + int numPoints = points.size(); + float[] buf = input.toLocationArray(width, height); + + NDManager manager = ctx.getNDManager(); + NDArray array = image.toNDArray(manager, Image.Flag.COLOR); + array = pipeline.transform(new NDList(array)).get(0).expandDims(0); + NDArray locations = manager.create(buf, new Shape(1, numPoints, 2)); + NDArray labels = manager.create(input.getLabels()); + + return new NDList(array, locations, labels); + } + + /** {@inheritDoc} */ + @Override + public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception { + NDArray logits = list.get(0); + NDArray scores = list.get(1).squeeze(0); + long best = scores.argMax().getLong(); + + int width = (Integer) ctx.getAttachment("width"); + int height = (Integer) ctx.getAttachment("height"); + + long[] size = {height, width}; + int mode = Image.Interpolation.BILINEAR.ordinal(); + logits = logits.getNDArrayInternal().interpolation(size, mode, false); + NDArray masks = logits.gt(0f).squeeze(0); + + float[][] dist = Mask.toMask(masks.get(best).toType(DataType.FLOAT32, true)); + Mask mask = new Mask(0, 0, width, height, dist, true); + double probability = scores.getFloat(best); + + List classes = Collections.singletonList(""); + List probabilities = Collections.singletonList(probability); + List boxes = Collections.singletonList(mask); + + return new DetectedObjects(classes, probabilities, boxes); + } + + /** A class represents the segment anything input. */ + public static final class Sam2Input { + + private Image image; + private List points; + private List labels; + + /** + * Constructs a {@code Sam2Input} instance. + * + * @param image the image + * @param points the locations on the image + * @param labels the labels for the locations (0: background, 1: foreground) + */ + public Sam2Input(Image image, List points, List labels) { + this.image = image; + this.points = points; + this.labels = labels; + } + + /** + * Returns the image. + * + * @return the image + */ + public Image getImage() { + return image; + } + + /** + * Returns the locations. + * + * @return the locations + */ + public List getPoints() { + return points; + } + + float[] toLocationArray(int width, int height) { + float[] ret = new float[points.size() * 2]; + int i = 0; + for (Point point : points) { + ret[i++] = (float) point.getX() / width * 1024; + ret[i++] = (float) point.getY() / height * 1024; + } + return ret; + } + + int[][] getLabels() { + return new int[][] {labels.stream().mapToInt(Integer::intValue).toArray()}; + } + + /** + * Creates a new {@code Sam2Input} instance with the image and a location. + * + * @param url the image url + * @param x the X of the location + * @param y the Y of the location + * @return a new {@code Sam2Input} instance + * @throws IOException if failed to read image + */ + public static Sam2Input newInstance(String url, int x, int y) throws IOException { + Image image = ImageFactory.getInstance().fromUrl(url); + List points = Collections.singletonList(new Point(x, y)); + List labels = Collections.singletonList(1); + return new Sam2Input(image, points, labels); + } + + /** + * Creates a new {@code Sam2Input} instance with the image and a location. + * + * @param path the image file path + * @param x the X of the location + * @param y the Y of the location + * @return a new {@code Sam2Input} instance + * @throws IOException if failed to read image + */ + public static Sam2Input newInstance(Path path, int x, int y) throws IOException { + Image image = ImageFactory.getInstance().fromFile(path); + List points = Collections.singletonList(new Point(x, y)); + List labels = Collections.singletonList(1); + return new Sam2Input(image, points, labels); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java new file mode 100644 index 00000000000..82fd8c6f6bb --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.CategoryMask; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link Sam2Translator} instance. */ +public class Sam2TranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(Sam2Input.class, DetectedObjects.class)); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + if (input == Image.class && output == CategoryMask.class) { + return (Translator) new Sam2Translator(); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index 07e56a5ca04..b25852ad5ab 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -453,6 +453,8 @@ default NDArray toTensor() { } } + NDArray interpolation(long[] size, int mode, boolean alignCorners); + NDArray resize(int width, int height, int interpolation); default NDArray crop(int x, int y, int width, int height) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index c7efd80eba3..00a7a0f254b 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -924,6 +924,12 @@ public NDArray toTensor() { return getManager().invoke("_npx__image_to_tensor", array, null); } + /** {@inheritDoc} */ + @Override + public NDArray interpolation(long[] size, int mode, boolean alignCorners) { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public NDArray resize(int width, int height, int interpolation) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index e9eb8b2b771..9fe5809fb70 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -565,6 +565,13 @@ public NDList lstm( batchFirst); } + /** {@inheritDoc} */ + @Override + public NDArray interpolation(long[] size, int mode, boolean alignCorners) { + return JniUtils.interpolate( + array.getManager().from(array), size, getInterpolationMode(mode), false); + } + /** {@inheritDoc} */ @Override public PtNDArray resize(int width, int height, int interpolation) { diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index 811d6870635..0ed61eff528 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -43,6 +43,8 @@ public class PtModelZoo extends ModelZoo { addModel( REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1")); addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "sam2-hiera-tiny", "0.0.1")); + addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "sam2-hiera-large", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-large/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-large/metadata.json new file mode 100644 index 00000000000..bc895406055 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-large/metadata.json @@ -0,0 +1,37 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.pytorch", + "artifactId": "sam2-hiera-large", + "name": "Mask generation", + "description": "Segment Anything in Images", + "website": "http://www.djl.ai/engines/pytorch/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "sam2-hiera-large", + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory" + }, + "options": { + "mapLocation": "true" + }, + "files": { + "model": { + "uri": "0.0.1/sam2-hiera-large.zip", + "name": "", + "sha1Hash": "5688c31f52ae086e0c17dd235f4047245dc42eb3", + "size": 834572454 + } + } + } + ] +} diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-tiny/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-tiny/metadata.json new file mode 100644 index 00000000000..71867e4548c --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/object_detection/ai/djl/pytorch/sam2-hiera-tiny/metadata.json @@ -0,0 +1,37 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/object_detection", + "groupId": "ai.djl.pytorch", + "artifactId": "sam2-hiera-tiny", + "name": "Mask generation", + "description": "Segment Anything in Images", + "website": "http://www.djl.ai/engines/pytorch/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "sam2-hiera-tiny", + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory" + }, + "options": { + "mapLocation": "true" + }, + "files": { + "model": { + "uri": "0.0.1/sam2-hiera-tiny.zip", + "name": "", + "sha1Hash": "c1eb858f0e8d53c7ec7c94434cd39b69d61db449", + "size": 145062696 + } + } + } + ] +} diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java index 5160a6b1c79..5d75f9110dd 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java @@ -504,6 +504,12 @@ public NDArray toTensor() { } } + /** {@inheritDoc} */ + @Override + public NDArray interpolation(long[] size, int mode, boolean alignCorners) { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public NDArray resize(int width, int height, int interpolation) { diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java new file mode 100644 index 00000000000..954880cc983 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.cv; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.Sam2Translator; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +public final class SegmentAnything2 { + + private static final Logger logger = LoggerFactory.getLogger(SegmentAnything2.class); + + private SegmentAnything2() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + DetectedObjects detection = predict(); + logger.info("{}", detection); + } + + public static DetectedObjects predict() throws IOException, ModelException, TranslateException { + String url = + "https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg"; + Sam2Input input = Sam2Input.newInstance(url, 500, 375); + + Criteria criteria = + Criteria.builder() + .setTypes(Sam2Input.class, DetectedObjects.class) + .optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny") + .optEngine("PyTorch") + .optTranslator(new Sam2Translator()) + .optProgress(new ProgressBar()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + DetectedObjects detection = predictor.predict(input); + showMask(input, detection); + return detection; + } + } + + private static void showMask(Sam2Input input, DetectedObjects detection) throws IOException { + Path outputDir = Paths.get("build/output"); + Files.createDirectories(outputDir); + + Image img = input.getImage(); + img.drawBoundingBoxes(detection, 0.8f); + img.drawMarks(input.getPoints()); + + Path imagePath = outputDir.resolve("sam2.png"); + img.save(Files.newOutputStream(imagePath), "png"); + logger.info("Segmentation result image has been saved in: {}", imagePath); + } +} diff --git a/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java b/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java new file mode 100644 index 00000000000..705832c4b23 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java @@ -0,0 +1,33 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference.cv; + +import ai.djl.ModelException; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class SegmentAnything2Test { + + @Test + public void testInstanceSegmentation() throws ModelException, TranslateException, IOException { + DetectedObjects result = SegmentAnything2.predict(); + Classifications.Classification best = result.best(); + Assert.assertTrue(Double.compare(best.getProbability(), 0.3) > 0); + } +} diff --git a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java index 14fd4112b50..6651a898167 100644 --- a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java +++ b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java @@ -47,6 +47,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -155,7 +156,7 @@ public void save(OutputStream os, String type) throws IOException { /** {@inheritDoc} */ @Override - public void drawBoundingBoxes(DetectedObjects detections) { + public void drawBoundingBoxes(DetectedObjects detections, float opacity) { int imageWidth = image.width(); int imageHeight = image.height(); @@ -191,7 +192,7 @@ public void drawBoundingBoxes(DetectedObjects detections) { if (box instanceof Mask) { Mask mask = (Mask) box; BufferedImage img = mat2Image(image); - drawMask(img, mask); + drawMask(img, mask, 0.5f); image = image2Mat(img); } else if (box instanceof Landmark) { drawLandmarks(box); @@ -199,6 +200,23 @@ public void drawBoundingBoxes(DetectedObjects detections) { } } + /** {@inheritDoc} */ + @Override + public void drawMarks(List points, int radius) { + Scalar color = new Scalar(190, 150, 37); + for (ai.djl.modality.cv.output.Point point : points) { + int[][] star = createStar(point, radius); + Point[] mat = new Point[10]; + for (int i = 0; i < 10; ++i) { + mat[i] = new Point(star[0][i], star[1][i]); + } + MatOfPoint mop = new MatOfPoint(); + mop.fromArray(mat); + List ppt = Collections.singletonList(mop); + Imgproc.fillPoly(image, ppt, color, Imgproc.LINE_AA); + } + } + /** {@inheritDoc} */ @Override public void drawJoints(Joints joints) { @@ -370,7 +388,7 @@ private void drawLandmarks(BoundingBox box) { } } - private void drawMask(BufferedImage img, Mask mask) { + private void drawMask(BufferedImage img, Mask mask, float ratio) { // TODO: use OpenCV native way to draw mask float r = RandomUtils.nextFloat(); float g = RandomUtils.nextFloat(); @@ -395,13 +413,15 @@ private void drawMask(BufferedImage img, Mask mask) { } } float[][] probDist = mask.getProbDist(); - float max = 0; - for (float[] row : probDist) { - for (float f : row) { - max = Math.max(max, f); + if (ratio < 0 || ratio > 1) { + float max = 0; + for (float[] row : probDist) { + for (float f : row) { + max = Math.max(max, f); + } } + ratio = 0.5f / max; } - float ratio = 0.5f / max; BufferedImage maskImage = new BufferedImage(probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB); diff --git a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java index eda79cb45ef..a9f41d087d0 100644 --- a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java +++ b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java @@ -20,6 +20,7 @@ import ai.djl.modality.cv.output.Joints; import ai.djl.modality.cv.output.Landmark; import ai.djl.modality.cv.output.Mask; +import ai.djl.modality.cv.output.Point; import ai.djl.modality.cv.output.Rectangle; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; @@ -117,6 +118,8 @@ public void testImage() throws IOException { Joints joints = new Joints(jointList); imgCopy.drawJoints(joints); + imgCopy.drawMarks(Collections.singletonList(new Point(20, 20))); + try (OutputStream os = Files.newOutputStream(Paths.get("build/newImage.png"))) { imgCopy.save(os, "png"); } diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java index f30037bf55a..3f84c1e062d 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java @@ -464,6 +464,12 @@ public NDList lstm( throw new UnsupportedOperationException("Not implemented"); } + /** {@inheritDoc} */ + @Override + public NDArray interpolation(long[] size, int mode, boolean alignCorners) { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public RsNDArray resize(int width, int height, int interpolation) {