Skip to content

Commit

Permalink
[api] Refactor ImageFeatureExtractor
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Sep 5, 2024
1 parent 581f1dd commit 7bbbe43
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
package ai.djl.modality.cv.translator;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;

import java.util.Map;

/**
* A generic {@link ai.djl.translate.Translator} for Image Classification feature extraction tasks.
*/
public class ImageFeatureExtractor extends BaseImageTranslator<byte[]> {
public class ImageFeatureExtractor extends BaseImageTranslator<float[]> {

/**
* Constructs an Image Classification using {@link Builder}.
Expand All @@ -33,8 +34,8 @@ public class ImageFeatureExtractor extends BaseImageTranslator<byte[]> {

/** {@inheritDoc} */
@Override
public byte[] processOutput(TranslatorContext ctx, NDList list) {
return list.get(0).toByteArray();
public float[] processOutput(TranslatorContext ctx, NDList list) {
return list.get(0).toType(DataType.FLOAT32, false).toFloatArray();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@
import java.util.Map;

/** A {@link TranslatorFactory} that creates an {@link ImageClassificationTranslator}. */
public class ImageFeatureExtractorFactory extends BaseImageTranslatorFactory<byte[]>
public class ImageFeatureExtractorFactory extends BaseImageTranslatorFactory<float[]>
implements Serializable {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
protected Translator<Image, byte[]> buildBaseTranslator(Model model, Map<String, ?> arguments) {
protected Translator<Image, float[]> buildBaseTranslator(
Model model, Map<String, ?> arguments) {
return ImageFeatureExtractor.builder(arguments).build();
}

/** {@inheritDoc} */
@Override
public Class<byte[]> getBaseOutputType() {
return byte[].class;
public Class<float[]> getBaseOutputType() {
return float[].class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,11 @@
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.modality.cv.translator.ImageFeatureExtractorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,6 +29,8 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public final class FeatureExtraction {

Expand All @@ -55,13 +51,25 @@ public static void main(String[] args) throws IOException, ModelException, Trans
public static float[] predict(Image img)
throws IOException, ModelException, TranslateException {
img.getWrappedImage();

List<Float> mean =
Arrays.asList(
127.5f / 255.0f,
127.5f / 255.0f,
127.5f / 255.0f,
128.0f / 255.0f,
128.0f / 255.0f,
128.0f / 255.0f);
String normalize = mean.stream().map(Object::toString).collect(Collectors.joining(","));

Criteria<Image, float[]> criteria =
Criteria.builder()
.setTypes(Image.class, float[].class)
.optModelUrls(
"https://resources.djl.ai/test-models/pytorch/face_feature.zip")
.optModelName("face_feature") // specify model file prefix
.optTranslator(new FaceFeatureTranslator())
.optArgument("normalize", normalize)
.optTranslatorFactory(new ImageFeatureExtractorFactory())
.optProgress(new ProgressBar())
.optEngine("PyTorch") // Use PyTorch engine
.build();
Expand All @@ -71,44 +79,4 @@ public static float[] predict(Image img)
return predictor.predict(img);
}
}

private static final class FaceFeatureTranslator implements Translator<Image, float[]> {

FaceFeatureTranslator() {}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
Pipeline pipeline = new Pipeline();
pipeline
// .add(new Resize(160))
.add(new ToTensor())
.add(
new Normalize(
new float[] {127.5f / 255.0f, 127.5f / 255.0f, 127.5f / 255.0f},
new float[] {
128.0f / 255.0f, 128.0f / 255.0f, 128.0f / 255.0f
}));

return pipeline.transform(new NDList(array));
}

/** {@inheritDoc} */
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDList result = new NDList();
long numOutputs = list.singletonOrThrow().getShape().get(0);
for (int i = 0; i < numOutputs; i++) {
result.add(list.singletonOrThrow().get(i));
}
float[][] embeddings =
result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
float[] feature = new float[embeddings.length];
for (int i = 0; i < embeddings.length; i++) {
feature[i] = embeddings[i][0];
}
return feature;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
Expand All @@ -29,11 +28,24 @@ public class FeatureExtractionTest {

@Test
public void testFeatureComparison() throws ModelException, TranslateException, IOException {
TestRequirements.linux();

Path imageFile = Paths.get("src/test/resources/kana1.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);
float[] feature = FeatureExtraction.predict(img);
Assert.assertEquals(feature.length, 512);
float[] expected = {
-0.040261813f,
-0.019486334f,
-0.09802657f,
0.017009983f,
0.037828982f,
0.030801114f,
-0.02714689f,
0.042024296f,
-0.009838469f,
-0.005961003f
};
float[] sub = new float[10];
System.arraycopy(feature, 0, sub, 0, 10);
Assert.assertEquals(sub, expected);
}
}

0 comments on commit 7bbbe43

Please sign in to comment.