From 7aa1af1c648faa305d50c584b51cd07e064e2d1e Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 5 Sep 2024 16:30:49 -0700 Subject: [PATCH] [api] Refactor ImageFeatureExtractor (#3455) --- .../cv/translator/ImageFeatureExtractor.java | 7 +- .../ImageFeatureExtractorFactory.java | 9 +-- .../inference/face/FeatureExtraction.java | 64 +++++-------------- .../inference/face/FeatureExtractionTest.java | 18 +++++- 4 files changed, 40 insertions(+), 58 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java index 8edde7a34d5..cb19ae7405b 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java @@ -13,6 +13,7 @@ package ai.djl.modality.cv.translator; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; import ai.djl.translate.TranslatorContext; import java.util.Map; @@ -20,7 +21,7 @@ /** * A generic {@link ai.djl.translate.Translator} for Image Classification feature extraction tasks. */ -public class ImageFeatureExtractor extends BaseImageTranslator { +public class ImageFeatureExtractor extends BaseImageTranslator { /** * Constructs an Image Classification using {@link Builder}. @@ -33,8 +34,8 @@ public class ImageFeatureExtractor extends BaseImageTranslator { /** {@inheritDoc} */ @Override - public byte[] processOutput(TranslatorContext ctx, NDList list) { - return list.get(0).toByteArray(); + public float[] processOutput(TranslatorContext ctx, NDList list) { + return list.get(0).toType(DataType.FLOAT32, false).toFloatArray(); } /** diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java index 82e9344a1aa..620b1a378f1 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java @@ -21,20 +21,21 @@ import java.util.Map; /** A {@link TranslatorFactory} that creates an {@link ImageClassificationTranslator}. */ -public class ImageFeatureExtractorFactory extends BaseImageTranslatorFactory +public class ImageFeatureExtractorFactory extends BaseImageTranslatorFactory implements Serializable { private static final long serialVersionUID = 1L; /** {@inheritDoc} */ @Override - protected Translator buildBaseTranslator(Model model, Map arguments) { + protected Translator buildBaseTranslator( + Model model, Map arguments) { return ImageFeatureExtractor.builder(arguments).build(); } /** {@inheritDoc} */ @Override - public Class getBaseOutputType() { - return byte[].class; + public Class getBaseOutputType() { + return float[].class; } } diff --git a/examples/src/main/java/ai/djl/examples/inference/face/FeatureExtraction.java b/examples/src/main/java/ai/djl/examples/inference/face/FeatureExtraction.java index 550cecdc307..f941fb58b73 100644 --- a/examples/src/main/java/ai/djl/examples/inference/face/FeatureExtraction.java +++ b/examples/src/main/java/ai/djl/examples/inference/face/FeatureExtraction.java @@ -16,17 +16,11 @@ import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.modality.cv.transform.Normalize; -import ai.djl.modality.cv.transform.ToTensor; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; +import ai.djl.modality.cv.translator.ImageFeatureExtractorFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.util.ProgressBar; -import ai.djl.translate.Pipeline; import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +29,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; public final class FeatureExtraction { @@ -55,13 +51,25 @@ public static void main(String[] args) throws IOException, ModelException, Trans public static float[] predict(Image img) throws IOException, ModelException, TranslateException { img.getWrappedImage(); + + List mean = + Arrays.asList( + 127.5f / 255.0f, + 127.5f / 255.0f, + 127.5f / 255.0f, + 128.0f / 255.0f, + 128.0f / 255.0f, + 128.0f / 255.0f); + String normalize = mean.stream().map(Object::toString).collect(Collectors.joining(",")); + Criteria criteria = Criteria.builder() .setTypes(Image.class, float[].class) .optModelUrls( "https://resources.djl.ai/test-models/pytorch/face_feature.zip") .optModelName("face_feature") // specify model file prefix - .optTranslator(new FaceFeatureTranslator()) + .optArgument("normalize", normalize) + .optTranslatorFactory(new ImageFeatureExtractorFactory()) .optProgress(new ProgressBar()) .optEngine("PyTorch") // Use PyTorch engine .build(); @@ -71,44 +79,4 @@ public static float[] predict(Image img) return predictor.predict(img); } } - - private static final class FaceFeatureTranslator implements Translator { - - FaceFeatureTranslator() {} - - /** {@inheritDoc} */ - @Override - public NDList processInput(TranslatorContext ctx, Image input) { - NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR); - Pipeline pipeline = new Pipeline(); - pipeline - // .add(new Resize(160)) - .add(new ToTensor()) - .add( - new Normalize( - new float[] {127.5f / 255.0f, 127.5f / 255.0f, 127.5f / 255.0f}, - new float[] { - 128.0f / 255.0f, 128.0f / 255.0f, 128.0f / 255.0f - })); - - return pipeline.transform(new NDList(array)); - } - - /** {@inheritDoc} */ - @Override - public float[] processOutput(TranslatorContext ctx, NDList list) { - NDList result = new NDList(); - long numOutputs = list.singletonOrThrow().getShape().get(0); - for (int i = 0; i < numOutputs; i++) { - result.add(list.singletonOrThrow().get(i)); - } - float[][] embeddings = - result.stream().map(NDArray::toFloatArray).toArray(float[][]::new); - float[] feature = new float[embeddings.length]; - for (int i = 0; i < embeddings.length; i++) { - feature[i] = embeddings[i][0]; - } - return feature; - } - } } diff --git a/examples/src/test/java/ai/djl/examples/inference/face/FeatureExtractionTest.java b/examples/src/test/java/ai/djl/examples/inference/face/FeatureExtractionTest.java index af50b240a0a..e3c5cbe1fc3 100644 --- a/examples/src/test/java/ai/djl/examples/inference/face/FeatureExtractionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/face/FeatureExtractionTest.java @@ -15,7 +15,6 @@ import ai.djl.ModelException; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; import org.testng.Assert; @@ -29,11 +28,24 @@ public class FeatureExtractionTest { @Test public void testFeatureComparison() throws ModelException, TranslateException, IOException { - TestRequirements.linux(); - Path imageFile = Paths.get("src/test/resources/kana1.jpg"); Image img = ImageFactory.getInstance().fromFile(imageFile); float[] feature = FeatureExtraction.predict(img); Assert.assertEquals(feature.length, 512); + float[] expected = { + -0.040261813f, + -0.019486334f, + -0.09802657f, + 0.017009983f, + 0.037828982f, + 0.030801114f, + -0.02714689f, + 0.042024296f, + -0.009838469f, + -0.005961003f + }; + float[] sub = new float[10]; + System.arraycopy(feature, 0, sub, 0, 10); + Assert.assertEquals(sub, expected, 0.0001f); } }