From 661f2155f4b98d4f9e8e17c060d1b4e19fe14bb3 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 13 Jun 2023 16:41:14 -0700 Subject: [PATCH] Refactor TextGenerator API --- .../textgeneration/PtGptTranslatorFactory.java | 6 +++--- .../inference/{ => nlp}/TextGenerationTest.java | 3 +-- .../djl/examples/inference/nlp/package-info.java | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) rename examples/src/test/java/ai/djl/examples/inference/{ => nlp}/TextGenerationTest.java (94%) create mode 100644 examples/src/test/java/ai/djl/examples/inference/nlp/package-info.java diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslatorFactory.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslatorFactory.java index 2c8de208d4e..78b0de17313 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslatorFactory.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslatorFactory.java @@ -25,7 +25,7 @@ import java.util.Map; import java.util.Set; -public class PtGptTranslatorFactory implements TranslatorFactory { +public class PtGptTranslatorFactory implements TranslatorFactory { private static final Set> SUPPORTED_TYPES = new HashSet<>(); @@ -48,8 +48,8 @@ public Translator newInstance( throw new IllegalArgumentException("Unsupported input/output types."); } long kvDim = ArgumentsUtil.longValue(arguments, "kvDim", 64); - int numAttentionHeads= ArgumentsUtil.intValue(arguments, "numAttentionHeads", 12); - int numLayers= ArgumentsUtil.intValue(arguments, "numLayers", 12); + int numAttentionHeads = ArgumentsUtil.intValue(arguments, "numAttentionHeads", 12); + int numLayers = ArgumentsUtil.intValue(arguments, "numLayers", 12); return (Translator) (new PtGptTranslator(kvDim, numAttentionHeads, numLayers)); } diff --git a/examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java similarity index 94% rename from examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java rename to examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java index 06d9de6189a..3e17738c687 100644 --- a/examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java @@ -10,10 +10,9 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.examples.inference; +package ai.djl.examples.inference.nlp; import ai.djl.ModelException; -import ai.djl.examples.inference.nlp.TextGeneration; import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/package-info.java b/examples/src/test/java/ai/djl/examples/inference/nlp/package-info.java new file mode 100644 index 00000000000..fa1d19a8000 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the inference examples. */ +package ai.djl.examples.inference.nlp;