diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
index 5ffb816efc6..e2bc0a7b33f 100644
--- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
+++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
@@ -50,17 +50,13 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
new GreedyBatchTensorList(inputIds, null, null, attentionMask);
while (true) {
try (NDScope ignore = new NDScope()) {
- long pastSeqLength =
- searchState.getPastOutputIds() == null
- ? 0
- : searchState.getPastOutputIds().getShape().getLastDimension();
- NDList modelInput =
- prepareInput(
- searchState.getNextInputIds(),
- searchState.getPastAttentionMask(),
- pastSeqLength,
- 1);
+ NDArray pastOutputIds = searchState.getPastOutputIds();
+ NDArray nextInputIds = searchState.getNextInputIds();
+ NDArray pastAttentionMask = searchState.getPastAttentionMask();
NDList pastKeyValues = searchState.getPastKeyValues();
+ long pastSeqLength =
+ pastOutputIds == null ? 0 : pastOutputIds.getShape().getLastDimension();
+ NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1);
if (pastKeyValues != null) {
modelInput.addAll(pastKeyValues);
}
@@ -69,31 +65,27 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits());
// Update searchState
- if (searchState.getPastOutputIds() == null) {
- searchState.setPastOutputIds(searchState.getNextInputIds());
+ if (pastOutputIds == null) {
+ pastOutputIds = nextInputIds;
+ searchState.setPastOutputIds(pastOutputIds);
} else {
- searchState.setPastOutputIds(
- searchState
- .getPastOutputIds()
- .concat(searchState.getNextInputIds(), 1));
+ pastOutputIds = pastOutputIds.concat(nextInputIds, 1);
+ searchState.setPastOutputIds(pastOutputIds);
}
- searchState.setNextInputIds(outputIds);
- searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
- searchState.setPastAttentionMask(
- searchState
- .getPastAttentionMask()
- .concat(
- manager.ones(
- new Shape(inputIds.getShape().get(0), 1),
- DataType.INT64),
- 1));
+ nextInputIds = outputIds;
+ searchState.setNextInputIds(nextInputIds);
+ pastKeyValues = modelOutput.getPastKeyValuesList();
+ searchState.setPastKeyValues(pastKeyValues);
+ pastAttentionMask =
+ pastAttentionMask.concat(
+ manager.ones(
+ new Shape(inputIds.getShape().get(0), 1), DataType.INT64),
+ 1);
+ searchState.setPastAttentionMask(pastAttentionMask);
// memory management
- NDScope.unregister(
- searchState.getNextInputIds(),
- searchState.getPastAttentionMask(),
- searchState.getPastOutputIds());
- NDScope.unregister(searchState.getPastKeyValues());
+ NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds);
+ NDScope.unregister(pastKeyValues);
}
// Termination Criteria
@@ -109,7 +101,7 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
* Generates text using beam search.
*
* @param inputIds input tokens ids
- * @see https://huggingface.co/blog/how-to-generate
+ * @see Beam Search
* @return output tensor
* @throws TranslateException if failed run forward
*/
@@ -227,7 +219,7 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {
* Generates text using contrastive search.
*
* @param inputIds input token ids
- * @see https://huggingface.co/blog/introducing-csearch
+ * @see Contrastive Search
* @return the generated {@code NDArray}
* @throws TranslateException if forward failed
*/
diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
index 54d71fe3f64..383712709b5 100644
--- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
+++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java
@@ -62,9 +62,10 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws
NDList pastKeyValuesOutput = output.subNDList(1, numLayers * 2 + 1);
NDArray hiddenStatesOutput;
if (output.size() > numLayers * 2 + 2) {
- // TODO: Why this can happen?
- hiddenStatesOutput = output.subNDList(numLayers * 2 + 2).get(0);
+ // TODO: Should this be 2 * numberLayers + 1?
+ hiddenStatesOutput = output.get(numLayers * 2 + 1);
} else {
+ // TODO: In which case this will happen?
hiddenStatesOutput = manager.zeros(new Shape(1));
}
diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java
new file mode 100644
index 00000000000..d8c9b0fbc38
--- /dev/null
+++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.pytorch.zoo.nlp.textgeneration;
+
+import ai.djl.ModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Block;
+import ai.djl.nn.LambdaBlock;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.TranslateException;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+public class TextGenerationTest {
+
+ @Test
+ public void testGpt2() throws TranslateException, ModelException, IOException {
+ Block block =
+ new LambdaBlock(
+ a -> {
+ NDList list = new NDList(25);
+ NDManager manager = a.getManager();
+ NDArray arr = manager.full(new Shape(1, 4, 50257), 257);
+ list.add(arr);
+
+ for (int i = 0; i < 12 * 2; ++i) {
+ NDArray array = manager.zeros(new Shape(1, 12, 1, 64));
+ list.add(array);
+ }
+ return list;
+ },
+ "model");
+
+ Path modelDir = Paths.get("build/text_generation");
+ Files.createDirectories(modelDir);
+
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, CausalLMOutput.class)
+ .optModelPath(modelDir)
+ .optBlock(block)
+ .optOption("hasParameter", "false")
+ .optTranslatorFactory(new PtGptTranslatorFactory())
+ .build();
+
+ try (ZooModel model = criteria.loadModel();
+ Predictor predictor = model.newPredictor();
+ NDManager manager = NDManager.newBaseManager()) {
+ long[][] inputIds = {{29744, 28478, 5834, 318}};
+ NDArray input = manager.create(inputIds);
+ CausalLMOutput res = predictor.predict(new NDList(input));
+ NDArray logits = res.getLogits();
+ Assert.assertEquals(logits.argMax().getLong(), 257);
+ NDList list = res.getPastKeyValuesList();
+ Assert.assertEquals(list.size(), 24);
+ Assert.assertEquals(res.getHiddenState().getShape().get(0), 1);
+ }
+ }
+}
diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java
new file mode 100644
index 00000000000..e1c0aaa1791
--- /dev/null
+++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+/**
+ * Contains unit test classes for text generation.
+ */
+package ai.djl.pytorch.zoo.nlp.textgeneration;
diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java
index 3e17738c687..c1d855ec89c 100644
--- a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java
+++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java
@@ -25,8 +25,8 @@ public class TextGenerationTest {
@Test
public void testTextGeneration() throws TranslateException, ModelException, IOException {
- TestRequirements.nightly();
- TestRequirements.engine("PyTorch");
+// TestRequirements.nightly();
+// TestRequirements.engine("PyTorch");
String expected =
"DeepMind Company is a global leader in the field of artificial"