Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 16, 2023
1 parent 661f215 commit 6511fce
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 37 deletions.
58 changes: 25 additions & 33 deletions api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,13 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
new GreedyBatchTensorList(inputIds, null, null, attentionMask);
while (true) {
try (NDScope ignore = new NDScope()) {
long pastSeqLength =
searchState.getPastOutputIds() == null
? 0
: searchState.getPastOutputIds().getShape().getLastDimension();
NDList modelInput =
prepareInput(
searchState.getNextInputIds(),
searchState.getPastAttentionMask(),
pastSeqLength,
1);
NDArray pastOutputIds = searchState.getPastOutputIds();
NDArray nextInputIds = searchState.getNextInputIds();
NDArray pastAttentionMask = searchState.getPastAttentionMask();
NDList pastKeyValues = searchState.getPastKeyValues();
long pastSeqLength =
pastOutputIds == null ? 0 : pastOutputIds.getShape().getLastDimension();
NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1);
if (pastKeyValues != null) {
modelInput.addAll(pastKeyValues);
}
Expand All @@ -69,31 +65,27 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits());

// Update searchState
if (searchState.getPastOutputIds() == null) {
searchState.setPastOutputIds(searchState.getNextInputIds());
if (pastOutputIds == null) {
pastOutputIds = nextInputIds;
searchState.setPastOutputIds(pastOutputIds);
} else {
searchState.setPastOutputIds(
searchState
.getPastOutputIds()
.concat(searchState.getNextInputIds(), 1));
pastOutputIds = pastOutputIds.concat(nextInputIds, 1);
searchState.setPastOutputIds(pastOutputIds);
}
searchState.setNextInputIds(outputIds);
searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
searchState.setPastAttentionMask(
searchState
.getPastAttentionMask()
.concat(
manager.ones(
new Shape(inputIds.getShape().get(0), 1),
DataType.INT64),
1));
nextInputIds = outputIds;
searchState.setNextInputIds(nextInputIds);
pastKeyValues = modelOutput.getPastKeyValuesList();
searchState.setPastKeyValues(pastKeyValues);
pastAttentionMask =
pastAttentionMask.concat(
manager.ones(
new Shape(inputIds.getShape().get(0), 1), DataType.INT64),
1);
searchState.setPastAttentionMask(pastAttentionMask);

// memory management
NDScope.unregister(
searchState.getNextInputIds(),
searchState.getPastAttentionMask(),
searchState.getPastOutputIds());
NDScope.unregister(searchState.getPastKeyValues());
NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds);
NDScope.unregister(pastKeyValues);
}

// Termination Criteria
Expand All @@ -109,7 +101,7 @@ public NDArray greedySearch(NDArray inputIds) throws TranslateException {
* Generates text using beam search.
*
* @param inputIds input tokens ids
* @see https://huggingface.co/blog/how-to-generate
* @see <a href="https://huggingface.co/blog/how-to-generate">Beam Search</a>
* @return output tensor
* @throws TranslateException if failed run forward
*/
Expand Down Expand Up @@ -227,7 +219,7 @@ public NDArray beamSearch(NDArray inputIds) throws TranslateException {
* Generates text using contrastive search.
*
* @param inputIds input token ids
* @see https://huggingface.co/blog/introducing-csearch
* @see <a href="https://huggingface.co/blog/introducing-csearch">Contrastive Search</a>
* @return the generated {@code NDArray}
* @throws TranslateException if forward failed
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,29 @@

import java.util.stream.Collectors;

/** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */
public class PtGptTranslator implements NoBatchifyTranslator<NDList, CausalLMOutput> {

private long kvDim;
private int numAttentionHeads;
private int numLayers;
private String tupleName;

/**
* Constructs a new instance of {@code PtGptTranslator}.
*
* @param kvDim the kv dimension
* @param numAttentionHeads the number of attention heads
* @param numLayers the number of layers
*/
public PtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
this.kvDim = kvDim;
this.numAttentionHeads = numAttentionHeads;
this.numLayers = numLayers;
tupleName = "past_key_values(" + numLayers + ',' + 2 + ')';
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
NDManager manager = ctx.getNDManager();
Expand All @@ -55,16 +64,18 @@ public NDList processInput(TranslatorContext ctx, NDList input) throws Exception
return input;
}

/** {@inheritDoc} */
@Override
public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws Exception {
NDArray logitsOutput = output.get(0);
NDManager manager = output.getManager();
NDList pastKeyValuesOutput = output.subNDList(1, numLayers * 2 + 1);
NDArray hiddenStatesOutput;
if (output.size() > numLayers * 2 + 2) {
// TODO: Why this can happen?
hiddenStatesOutput = output.subNDList(numLayers * 2 + 2).get(0);
// TODO: Should this be 2 * numberLayers + 1?
hiddenStatesOutput = output.get(numLayers * 2 + 1);
} else {
// TODO: In which case this will happen?
hiddenStatesOutput = manager.zeros(new Shape(1));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Set;

/** An {@link TranslatorFactory} that creates a {@link PtGptTranslator} instance. */
public class PtGptTranslatorFactory implements TranslatorFactory {

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains classes for the {@link ai.djl.Application.NLP#TEXT_GENERATION} models. */
package ai.djl.pytorch.zoo.nlp.textgeneration;
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.zoo.nlp.textgeneration;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class TextGenerationTest {

@Test
public void testGpt2() throws TranslateException, ModelException, IOException {
Block block =
new LambdaBlock(
a -> {
NDList list = new NDList(25);
NDManager manager = a.getManager();
long[][] logits = new long[4][50257];
logits[3][257] = 1;
NDArray arr = manager.create(logits).expandDims(0);
list.add(arr);

for (int i = 0; i < 12 * 2; ++i) {
NDArray array = manager.zeros(new Shape(1, 12, 1, 64));
list.add(array);
}
return list;
},
"model");

Path modelDir = Paths.get("build/text_generation");
Files.createDirectories(modelDir);

Criteria<NDList, CausalLMOutput> criteria =
Criteria.builder()
.setTypes(NDList.class, CausalLMOutput.class)
.optModelPath(modelDir)
.optBlock(block)
.optOption("hasParameter", "false")
.optTranslatorFactory(new PtGptTranslatorFactory())
.build();

try (ZooModel<NDList, CausalLMOutput> model = criteria.loadModel();
Predictor<NDList, CausalLMOutput> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager()) {
long[][] inputIds = {{29744, 28478, 5834, 318}};
int len = inputIds[0].length;
NDArray input = manager.create(inputIds);
NDArray attentionMask = manager.ones(new Shape(1, len), DataType.INT64);
NDArray positionIds = manager.arange(0, len, 1, DataType.INT64).expandDims(0);
CausalLMOutput res = predictor.predict(new NDList(input, attentionMask, positionIds));
NDArray logits = res.getLogits();
long nextTokenId = logits.get(":, -1, :").argMax().getLong();
Assert.assertEquals(nextTokenId, 257);
NDList list = res.getPastKeyValuesList();
Assert.assertEquals(list.size(), 24);
Assert.assertEquals(res.getHiddenState().getShape().get(0), 1);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/**
* Contains unit test classes for text generation.
*/
package ai.djl.pytorch.zoo.nlp.textgeneration;
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ public class TextGenerationTest {

@Test
public void testTextGeneration() throws TranslateException, ModelException, IOException {
TestRequirements.nightly();
TestRequirements.engine("PyTorch");
// TestRequirements.nightly();
// TestRequirements.engine("PyTorch");

String expected =
"DeepMind Company is a global leader in the field of artificial"
Expand Down

0 comments on commit 6511fce

Please sign in to comment.