diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index 5ffb816efc6..e2bc0a7b33f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -50,17 +50,13 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { new GreedyBatchTensorList(inputIds, null, null, attentionMask); while (true) { try (NDScope ignore = new NDScope()) { - long pastSeqLength = - searchState.getPastOutputIds() == null - ? 0 - : searchState.getPastOutputIds().getShape().getLastDimension(); - NDList modelInput = - prepareInput( - searchState.getNextInputIds(), - searchState.getPastAttentionMask(), - pastSeqLength, - 1); + NDArray pastOutputIds = searchState.getPastOutputIds(); + NDArray nextInputIds = searchState.getNextInputIds(); + NDArray pastAttentionMask = searchState.getPastAttentionMask(); NDList pastKeyValues = searchState.getPastKeyValues(); + long pastSeqLength = + pastOutputIds == null ? 0 : pastOutputIds.getShape().getLastDimension(); + NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1); if (pastKeyValues != null) { modelInput.addAll(pastKeyValues); } @@ -69,31 +65,27 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits()); // Update searchState - if (searchState.getPastOutputIds() == null) { - searchState.setPastOutputIds(searchState.getNextInputIds()); + if (pastOutputIds == null) { + pastOutputIds = nextInputIds; + searchState.setPastOutputIds(pastOutputIds); } else { - searchState.setPastOutputIds( - searchState - .getPastOutputIds() - .concat(searchState.getNextInputIds(), 1)); + pastOutputIds = pastOutputIds.concat(nextInputIds, 1); + searchState.setPastOutputIds(pastOutputIds); } - searchState.setNextInputIds(outputIds); - searchState.setPastKeyValues(modelOutput.getPastKeyValuesList()); - searchState.setPastAttentionMask( - searchState - .getPastAttentionMask() - .concat( - manager.ones( - new Shape(inputIds.getShape().get(0), 1), - DataType.INT64), - 1)); + nextInputIds = outputIds; + searchState.setNextInputIds(nextInputIds); + pastKeyValues = modelOutput.getPastKeyValuesList(); + searchState.setPastKeyValues(pastKeyValues); + pastAttentionMask = + pastAttentionMask.concat( + manager.ones( + new Shape(inputIds.getShape().get(0), 1), DataType.INT64), + 1); + searchState.setPastAttentionMask(pastAttentionMask); // memory management - NDScope.unregister( - searchState.getNextInputIds(), - searchState.getPastAttentionMask(), - searchState.getPastOutputIds()); - NDScope.unregister(searchState.getPastKeyValues()); + NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds); + NDScope.unregister(pastKeyValues); } // Termination Criteria @@ -109,7 +101,7 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException { * Generates text using beam search. * * @param inputIds input tokens ids - * @see https://huggingface.co/blog/how-to-generate + * @see Beam Search * @return output tensor * @throws TranslateException if failed run forward */ @@ -227,7 +219,7 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException { * Generates text using contrastive search. * * @param inputIds input token ids - * @see https://huggingface.co/blog/introducing-csearch + * @see Contrastive Search * @return the generated {@code NDArray} * @throws TranslateException if forward failed */ diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java index 54d71fe3f64..383712709b5 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java @@ -62,9 +62,10 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws NDList pastKeyValuesOutput = output.subNDList(1, numLayers * 2 + 1); NDArray hiddenStatesOutput; if (output.size() > numLayers * 2 + 2) { - // TODO: Why this can happen? - hiddenStatesOutput = output.subNDList(numLayers * 2 + 2).get(0); + // TODO: Should this be 2 * numberLayers + 1? + hiddenStatesOutput = output.get(numLayers * 2 + 1); } else { + // TODO: In which case this will happen? hiddenStatesOutput = manager.zeros(new Shape(1)); } diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java new file mode 100644 index 00000000000..d8c9b0fbc38 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.zoo.nlp.textgeneration; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.nlp.generate.CausalLMOutput; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Block; +import ai.djl.nn.LambdaBlock; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +public class TextGenerationTest { + + @Test + public void testGpt2() throws TranslateException, ModelException, IOException { + Block block = + new LambdaBlock( + a -> { + NDList list = new NDList(25); + NDManager manager = a.getManager(); + NDArray arr = manager.full(new Shape(1, 4, 50257), 257); + list.add(arr); + + for (int i = 0; i < 12 * 2; ++i) { + NDArray array = manager.zeros(new Shape(1, 12, 1, 64)); + list.add(array); + } + return list; + }, + "model"); + + Path modelDir = Paths.get("build/text_generation"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, CausalLMOutput.class) + .optModelPath(modelDir) + .optBlock(block) + .optOption("hasParameter", "false") + .optTranslatorFactory(new PtGptTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = NDManager.newBaseManager()) { + long[][] inputIds = {{29744, 28478, 5834, 318}}; + NDArray input = manager.create(inputIds); + CausalLMOutput res = predictor.predict(new NDList(input)); + NDArray logits = res.getLogits(); + Assert.assertEquals(logits.argMax().getLong(), 257); + NDList list = res.getPastKeyValuesList(); + Assert.assertEquals(list.size(), 24); + Assert.assertEquals(res.getHiddenState().getShape().get(0), 1); + } + } +} diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java new file mode 100644 index 00000000000..e1c0aaa1791 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/package-info.java @@ -0,0 +1,16 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** + * Contains unit test classes for text generation. + */ +package ai.djl.pytorch.zoo.nlp.textgeneration; diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java index 3e17738c687..c1d855ec89c 100644 --- a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java @@ -25,8 +25,8 @@ public class TextGenerationTest { @Test public void testTextGeneration() throws TranslateException, ModelException, IOException { - TestRequirements.nightly(); - TestRequirements.engine("PyTorch"); +// TestRequirements.nightly(); +// TestRequirements.engine("PyTorch"); String expected = "DeepMind Company is a global leader in the field of artificial"