Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 16, 2023
1 parent 661f215 commit c22ee77
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 37 deletions.
58 changes: 25 additions & 33 deletions api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,13 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
new GreedyBatchTensorList(inputIds, null, null, attentionMask);
while (true) {
try (NDScope ignore = new NDScope()) {
long pastSeqLength =
searchState.getPastOutputIds() == null
? 0
: searchState.getPastOutputIds().getShape().getLastDimension();
NDList modelInput =
prepareInput(
searchState.getNextInputIds(),
searchState.getPastAttentionMask(),
pastSeqLength,
1);
NDArray pastOutputIds = searchState.getPastOutputIds();
NDArray nextInputIds = searchState.getNextInputIds();
NDArray pastAttentionMask = searchState.getPastAttentionMask();
NDList pastKeyValues = searchState.getPastKeyValues();
long pastSeqLength =
pastOutputIds == null ? 0 : pastOutputIds.getShape().getLastDimension();
NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1);
if (pastKeyValues != null) {
modelInput.addAll(pastKeyValues);
}
Expand All @@ -69,31 +65,27 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits());

// Update searchState
if (searchState.getPastOutputIds() == null) {
searchState.setPastOutputIds(searchState.getNextInputIds());
if (pastOutputIds == null) {
pastOutputIds = nextInputIds;
searchState.setPastOutputIds(pastOutputIds);
} else {
searchState.setPastOutputIds(
searchState
.getPastOutputIds()
.concat(searchState.getNextInputIds(), 1));
pastOutputIds = pastOutputIds.concat(nextInputIds, 1);
searchState.setPastOutputIds(pastOutputIds);
}
searchState.setNextInputIds(outputIds);
searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
searchState.setPastAttentionMask(
searchState
.getPastAttentionMask()
.concat(
manager.ones(
new Shape(inputIds.getShape().get(0), 1),
DataType.INT64),
1));
nextInputIds = outputIds;
searchState.setNextInputIds(nextInputIds);
pastKeyValues = modelOutput.getPastKeyValuesList();
searchState.setPastKeyValues(pastKeyValues);
pastAttentionMask =
pastAttentionMask.concat(
manager.ones(
new Shape(inputIds.getShape().get(0), 1), DataType.INT64),
1);
searchState.setPastAttentionMask(pastAttentionMask);

// memory management
NDScope.unregister(
searchState.getNextInputIds(),
searchState.getPastAttentionMask(),
searchState.getPastOutputIds());
NDScope.unregister(searchState.getPastKeyValues());
NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds);
NDScope.unregister(pastKeyValues);
}

// Termination Criteria
Expand All @@ -109,7 +101,7 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
* Generates text using beam search.
*
* @param inputIds input tokens ids
* @see https://huggingface.co/blog/how-to-generate
* @see <a href="https://huggingface.co/blog/how-to-generate">Beam Search</a>
* @return output tensor
* @throws TranslateException if failed run forward
*/
Expand Down Expand Up @@ -227,7 +219,7 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {
* Generates text using contrastive search.
*
* @param inputIds input token ids
* @see https://huggingface.co/blog/introducing-csearch
* @see <a href="https://huggingface.co/blog/introducing-csearch">Contrastive Search</a>
* @return the generated {@code NDArray}
* @throws TranslateException if forward failed
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws
NDList pastKeyValuesOutput = output.subNDList(1, numLayers * 2 + 1);
NDArray hiddenStatesOutput;
if (output.size() > numLayers * 2 + 2) {
// TODO: Why this can happen?
hiddenStatesOutput = output.subNDList(numLayers * 2 + 2).get(0);
// TODO: Should this be 2 * numberLayers + 1?
hiddenStatesOutput = output.get(numLayers * 2 + 1);
} else {
// TODO: In which case this will happen?
hiddenStatesOutput = manager.zeros(new Shape(1));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.zoo.nlp.textgeneration;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class TextGenerationTest {

@Test
public void testGpt2() throws TranslateException, ModelException, IOException {
Block block =
new LambdaBlock(
a -> {
NDList list = new NDList(25);
NDManager manager = a.getManager();
NDArray arr = manager.full(new Shape(1, 4, 50257), 257);
list.add(arr);

for (int i = 0; i < 12 * 2; ++i) {
NDArray array = manager.zeros(new Shape(1, 12, 1, 64));
list.add(array);
}
return list;
},
"model");

Path modelDir = Paths.get("build/text_generation");
Files.createDirectories(modelDir);

Criteria<NDList, CausalLMOutput> criteria =
Criteria.builder()
.setTypes(NDList.class, CausalLMOutput.class)
.optModelPath(modelDir)
.optBlock(block)
.optOption("hasParameter", "false")
.optTranslatorFactory(new PtGptTranslatorFactory())
.build();

try (ZooModel<NDList, CausalLMOutput> model = criteria.loadModel();
Predictor<NDList, CausalLMOutput> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager()) {
long[][] inputIds = {{29744, 28478, 5834, 318}};
NDArray input = manager.create(inputIds);
CausalLMOutput res = predictor.predict(new NDList(input));
NDArray logits = res.getLogits();
Assert.assertEquals(logits.argMax().getLong(), 257);
NDList list = res.getPastKeyValuesList();
Assert.assertEquals(list.size(), 24);
Assert.assertEquals(res.getHiddenState().getShape().get(0), 1);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/**
* Contains unit test classes for text generation.
*/
package ai.djl.pytorch.zoo.nlp.textgeneration;
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ public class TextGenerationTest {

@Test
public void testTextGeneration() throws TranslateException, ModelException, IOException {
TestRequirements.nightly();
TestRequirements.engine("PyTorch");
// TestRequirements.nightly();
// TestRequirements.engine("PyTorch");

String expected =
"DeepMind Company is a global leader in the field of artificial"
Expand Down

0 comments on commit c22ee77

Please sign in to comment.