Skip to content

Commit

Permalink
[api] Adds mask generation task for sam2 model (#3450)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Sep 3, 2024
1 parent 555c596 commit cde2221
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 3 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public static Application of(String path) {
case "cv/instance_segmentation":
case "instance_segmentation":
return CV.INSTANCE_SEGMENTATION;
case "cv/mask_generation":
case "mask_generation":
return CV.MASK_GENERATION;
case "cv/pose_estimation":
case "pose_estimation":
return CV.POSE_ESTIMATION;
Expand Down Expand Up @@ -196,6 +199,12 @@ public interface CV {
*/
Application INSTANCE_SEGMENTATION = new Application("cv/instance_segmentation");

/**
* An application that generates masks that identify a specific object or region of interest
* in a given image.
*/
Application MASK_GENERATION = new Application("cv/mask_generation");

/**
* An application that accepts an image of a single person and returns the {@link
* ai.djl.modality.cv.output.Joints} locations of the person.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ public class PtModelZoo extends ModelZoo {
addModel(
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "sam2-hiera-tiny", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "sam2-hiera-large", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny-gpu", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large-gpu", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/mask_generation",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-large-gpu",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-large-gpu",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-large-gpu.zip",
"name": "",
"sha1Hash": "0fb0399ca091edf54378348b7b99777bf8776603",
"size": 834565732
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/mask_generation",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-large",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-large",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-large.zip",
"name": "",
"sha1Hash": "5688c31f52ae086e0c17dd235f4047245dc42eb3",
"size": 834572454
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/mask_generation",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-tiny-gpu",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-tiny-gpu",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-tiny-gpu.zip",
"name": "",
"sha1Hash": "41440632b2f2d481282b8cd7004d37cc3c6f9a16",
"size": 145037570
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/mask_generation",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-tiny",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-tiny",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-tiny.zip",
"name": "",
"sha1Hash": "c1eb858f0e8d53c7ec7c94434cd39b69d61db449",
"size": 145062696
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/object_detection",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-large-gpu",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-large-gpu",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-large-gpu.zip",
"name": "",
"sha1Hash": "0fb0399ca091edf54378348b7b99777bf8776603",
"size": 834565732
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/object_detection",
"groupId": "ai.djl.pytorch",
"artifactId": "sam2-hiera-tiny-gpu",
"name": "Mask generation",
"description": "Segment Anything in Images",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "sam2-hiera-tiny-gpu",
"arguments": {
"translatorFactory": "ai.djl.modality.cv.translator.Sam2TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/sam2-hiera-tiny-gpu.zip",
"name": "",
"sha1Hash": "41440632b2f2d481282b8cd7004d37cc3c6f9a16",
"size": 145037570
}
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
.setTypes(Sam2Input.class, DetectedObjects.class)
.optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny")
.optEngine("PyTorch")
.optDevice(Device.cpu()) // this model only works on CPU
.optDevice(Device.cpu()) // use sam2-hiera-tiny-gpu for GPU
.optTranslator(new Sam2Translator())
.optProgress(new ProgressBar())
.build();
Expand Down

0 comments on commit cde2221

Please sign in to comment.