Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] implements text-generation search algorithm #2637

Merged
merged 10 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public static Application of(String path) {
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/text_generation":
return NLP.TEXT_GENERATION;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -261,6 +263,8 @@ public interface NLP {
*/
Application WORD_EMBEDDING = new Application("nlp/word_embedding");

Application TEXT_GENERATION = new Application("nlp/text_generation");

/**
* An application that translates text from one language to another.
*
Expand Down
174 changes: 174 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * Holds the tensors that make up one search state; the NDArrays it carries are updated on every
 * iteration of the autoregressive loop.
 *
 * <p>It is a plain struct of NDArrays. Each of them has batch as its first dimension and also has
 * a sequence dimension, whose position within the tensor's shape is given by {@code seqDimOrder}.
 * The SeqBatcher batch operations work on these two dimensions.
 */
public abstract class BatchTensorList {

    // Position of the sequence dimension among all dimensions, one entry per element of the
    // batch list.
    private long[] seqDimOrder;

    // Shape: [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Grows over time.
    private NDArray pastOutputIds;

    // Shape: [batch, seq_past].
    // Cached attention mask of the past. seq-dim-size == |past_seq| + |inputIds|. Grows over time.
    private NDArray pastAttentionMask;

    // (k, v) * numLayer, where each of k and v is [batch, heads, seq_past, kvfeature].
    // Cached past sequence. seq-dim-size == |past_seq| + |inputIds|. Grows over time.
    private NDList pastKeyValues;

    /** Constructs an empty {@code BatchTensorList}; every field is left unset. */
    BatchTensorList() {}

    /**
     * Constructs a new {@code BatchTensorList} instance from its serialized form.
     *
     * <p>Element 0 of the list is taken as the past output ids, element 1 as the past attention
     * mask, and {@code list.subNDList(2)} as the kv cache.
     *
     * @param list the NDList that contains the serialized version of the batch tensors
     * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
     *     is in a tensor's shape
     */
    BatchTensorList(NDList list, long[] seqDimOrder) {
        this(list.get(0), list.get(1), list.subNDList(2), seqDimOrder);
    }

    /**
     * Constructs a new {@code BatchTensorList} instance from its individual tensors.
     *
     * @param pastOutputIds past output token ids
     * @param pastAttentionMask past attention mask
     * @param pastKeyValues past kv cache
     * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
     *     is in a tensor's shape
     */
    BatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        this.seqDimOrder = seqDimOrder;
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = pastKeyValues;
    }

    /**
     * Builds a {@code BatchTensorList} out of the serialized version of the batch tensors.
     *
     * <p>The pastOutputIds has to be the first element of the list.
     *
     * @param inputList the serialized version of the batch tensors
     * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
     *     is in a tensor's shape
     * @return the resulting BatchTensorList
     */
    public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

    /**
     * Serializes this BatchTensorList into an NDList. The pastOutputIds has to be the first
     * element of the returned list.
     *
     * @return the NDList that contains the serialized BatchTensorList
     */
    public abstract NDList getList();

    /**
     * Returns the sequence dimension order, i.e. where the sequence dimension sits in a tensor's
     * shape.
     *
     * @return the sequence dimension order
     */
    public long[] getSeqDimOrder() {
        return this.seqDimOrder;
    }

    /**
     * Replaces the sequence dimension order, i.e. where the sequence dimension sits in a tensor's
     * shape.
     *
     * @param seqDimOrder the new sequence dimension order
     */
    public void setSeqDimOrder(long[] seqDimOrder) {
        this.seqDimOrder = seqDimOrder;
    }

    /**
     * Returns the past output token ids.
     *
     * @return the current value of pastOutputIds
     */
    public NDArray getPastOutputIds() {
        return this.pastOutputIds;
    }

    /**
     * Replaces the past output token ids.
     *
     * @param pastOutputIds the past output token ids
     */
    public void setPastOutputIds(NDArray pastOutputIds) {
        this.pastOutputIds = pastOutputIds;
    }

    /**
     * Returns the past attention mask.
     *
     * @return the current value of pastAttentionMask
     */
    public NDArray getPastAttentionMask() {
        return this.pastAttentionMask;
    }

    /**
     * Replaces the past attention mask.
     *
     * @param pastAttentionMask the attention mask
     */
    public void setPastAttentionMask(NDArray pastAttentionMask) {
        this.pastAttentionMask = pastAttentionMask;
    }

    /**
     * Returns the kv cache.
     *
     * @return the current value of pastKeyValues
     */
    public NDList getPastKeyValues() {
        return this.pastKeyValues;
    }

    /**
     * Replaces the kv cache.
     *
     * @param pastKeyValues the kv cache
     */
    public void setPastKeyValues(NDList pastKeyValues) {
        this.pastKeyValues = pastKeyValues;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * A {@link BatchTensorList} whose tensors carry an additional beam dimension right after the batch
 * dimension.
 *
 * <p>The past output ids, past attention mask and kv cache are stored in the inherited {@link
 * BatchTensorList} fields and accessed through the inherited getters and setters. In this class
 * their shapes are:
 *
 * <ul>
 *   <li>pastAttentionMask: [batch, beam, seq_past + new_seq]. The cache of past attentionMask.
 *       seq-dim-size == |past_seq| + |inputIds|. Will grow.
 *   <li>pastOutputIds: [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will
 *       grow.
 *   <li>pastKeyValues: (k, v) * numLayer, with kv: [batch, beam, heads, seq_past, kvfeature]. The
 *       cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
 * </ul>
 *
 * <p>pastOutputIds and pastKeyValues are one time step behind nextInputIds, lastProbs and
 * pastAttentionMask, i.e. they contain the whole past sequence but exclude the time step that
 * corresponds to nextInputIds.
 */
class BeamBatchTensorList extends BatchTensorList {

    // [batch, beam, seq=1]
    private NDArray nextInputIds;

    // [batch, beam]
    private NDArray lastProbs;

    /** Constructs an empty {@code BeamBatchTensorList}; every field is left unset. */
    BeamBatchTensorList() {}

    /**
     * Constructs a new {@code BeamBatchTensorList} instance.
     *
     * @param nextInputIds the token ids to feed at the next step
     * @param pastOutputIds past output token ids
     * @param pastKeyValues past kv cache
     * @param pastAttentionMask past attention mask
     * @param lastProb the value stored as lastProbs
     */
    BeamBatchTensorList(
            NDArray nextInputIds,
            NDArray pastOutputIds,
            NDList pastKeyValues,
            NDArray pastAttentionMask,
            NDArray lastProb) {
        // The shared state lives in the superclass instead of in shadowing copies here.
        // seqDimOrder stays null, as it was never set by this constructor.
        super(pastOutputIds, pastAttentionMask, pastKeyValues, null);
        this.nextInputIds = nextInputIds;
        this.lastProbs = lastProb;
    }

    /**
     * Returns a new, empty {@code BeamBatchTensorList}.
     *
     * <p>NOTE(review): this is a placeholder; both arguments are currently ignored and no tensor
     * is read from {@code inputList}.
     *
     * @param inputList the serialized version of the batch tensors (currently unused)
     * @param seqDimOrder the sequence dimension order (currently unused)
     * @return a new empty {@code BeamBatchTensorList}
     */
    @Override
    public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
        return new BeamBatchTensorList();
    }

    /**
     * Returns a new, empty NDList.
     *
     * <p>NOTE(review): this is a placeholder; none of the tensors held by this object are put
     * into the returned list.
     *
     * @return a new empty NDList
     */
    @Override
    public NDList getList() {
        return new NDList();
    }

    /**
     * Returns the value of the nextInputIds.
     *
     * @return the value of nextInputIds
     */
    public NDArray getNextInputIds() {
        return nextInputIds;
    }

    /**
     * Sets the value of the nextInputIds.
     *
     * @param nextInputIds the new value of nextInputIds
     */
    public void setNextInputIds(NDArray nextInputIds) {
        this.nextInputIds = nextInputIds;
    }

    /**
     * Returns the value of the lastProbs.
     *
     * @return the value of lastProbs
     */
    public NDArray getLastProbs() {
        return lastProbs;
    }

    /**
     * Sets the value of the lastProbs.
     *
     * @param lastProbs the new value of lastProbs
     */
    public void setLastProbs(NDArray lastProbs) {
        this.lastProbs = lastProbs;
    }
}
Loading