Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] implements text-generation search algorithm #2637

Merged
merged 10 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public static Application of(String path) {
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/text_generation":
return NLP.TEXT_GENERATION;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -261,6 +263,8 @@ public interface NLP {
*/
Application WORD_EMBEDDING = new Application("nlp/word_embedding");

Application TEXT_GENERATION = new Application("nlp/text_generation");

/**
* An application that translates text from one language to another.
*
Expand Down
111 changes: 111 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * Represents a search state whose NDArrays are updated in each iteration of the autoregressive
 * loop.
 *
 * <p>It is a struct of NDArrays whose first dimension is batch and which also contain a sequence
 * dimension (its position in a tensor's shape is specified by {@code seqDimOrder}). The SeqBatcher
 * batch operations operate on these two dimensions.
 */
public abstract class BatchTensorList {

    // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDArray pastOutputIds;

    // [batch, seq_past]
    // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDArray pastAttentionMask;

    // (k, v) * numLayer,
    // kv: [batch, heads, seq_past, kvfeature]
    // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDList pastKeyValues;

    // Sequence dimension order among all dimensions for each element in the batch list.
    private long[] seqDimOrder;

    /** Creates an empty instance; every field is left unset. */
    BatchTensorList() {}

    /**
     * Creates an instance from a flat list.
     *
     * <p>Element 0 becomes {@code pastOutputIds}, element 1 becomes {@code pastAttentionMask}, and
     * the sub-list starting at index 2 becomes {@code pastKeyValues}.
     *
     * @param list the flat list to unpack
     * @param seqDimOrder the sequence dimension order to store
     */
    BatchTensorList(NDList list, long[] seqDimOrder) {
        this(list.get(0), list.get(1), list.subNDList(2), seqDimOrder);
    }

    /**
     * Creates an instance from its individual components.
     *
     * @param pastOutputIds the past output ids
     * @param pastAttentionMask the past attention mask
     * @param pastKeyValues the past key/value cache
     * @param seqDimOrder the sequence dimension order to store
     */
    BatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = pastKeyValues;
        this.seqDimOrder = seqDimOrder;
    }

    /**
     * Builds a {@code BatchTensorList} from a list and a sequence dimension order.
     *
     * @param inputList the list to build from
     * @param seqDimOrder the sequence dimension order
     * @return the resulting {@code BatchTensorList}
     */
    public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

    /**
     * Gets the content of this object as a list.
     *
     * <p>The pastOutputIds has to be the first in the output list.
     *
     * @return the list representation
     */
    public abstract NDList getList();

    /**
     * Gets the value of the pastOutputIds.
     *
     * @return the value of pastOutputIds
     */
    public NDArray getPastOutputIds() {
        return this.pastOutputIds;
    }

    /**
     * Sets the value of the pastOutputIds.
     *
     * @param pastOutputIds the new value of pastOutputIds
     */
    public void setPastOutputIds(NDArray pastOutputIds) {
        this.pastOutputIds = pastOutputIds;
    }

    /**
     * Gets the value of the pastAttentionMask.
     *
     * @return the value of pastAttentionMask
     */
    public NDArray getPastAttentionMask() {
        return this.pastAttentionMask;
    }

    /**
     * Sets the value of the pastAttentionMask.
     *
     * @param pastAttentionMask the new value of pastAttentionMask
     */
    public void setPastAttentionMask(NDArray pastAttentionMask) {
        this.pastAttentionMask = pastAttentionMask;
    }

    /**
     * Gets the value of the pastKeyValues.
     *
     * @return the value of pastKeyValues
     */
    public NDList getPastKeyValues() {
        return this.pastKeyValues;
    }

    /**
     * Sets the value of the pastKeyValues.
     *
     * @param pastKeyValues the new value of pastKeyValues
     */
    public void setPastKeyValues(NDList pastKeyValues) {
        this.pastKeyValues = pastKeyValues;
    }

    /**
     * Gets the value of the seqDimOrder.
     *
     * @return the stored array itself (no copy is made)
     */
    public long[] getSeqDimOrder() {
        return this.seqDimOrder;
    }

    /**
     * Sets the value of the seqDimOrder.
     *
     * @param seqDimOrder the new value; the array reference is stored as is
     */
    public void setSeqDimOrder(long[] seqDimOrder) {
        this.seqDimOrder = seqDimOrder;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * A {@link BatchTensorList} that additionally carries the next input ids and the last
 * probabilities, with a beam dimension following the batch dimension.
 *
 * <p>This class keeps its own {@code pastOutputIds}, {@code pastAttentionMask} and {@code
 * pastKeyValues} fields and overrides the corresponding accessors to read and write them.
 */
class BeamBatchTensorList extends BatchTensorList {

    // [batch, beam, seq=1]
    private NDArray nextInputIds;

    // [batch, beam]
    private NDArray lastProbs;

    // [batch, beam, seq_past + new_seq]
    // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDArray pastAttentionMask;

    /*
     * The variables below are one time step behind the state variables above, i.e. they contain
     * all the past sequence but exclude the time step that corresponds to the above input.
     */

    // [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDArray pastOutputIds;

    // (k, v) * numLayer,
    // kv: [batch, beam, heads, seq_past, kvfeature]
    // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    private NDList pastKeyValues;

    /** Creates an empty instance; every field is left unset. */
    BeamBatchTensorList() {}

    /**
     * Creates an instance from its individual components.
     *
     * @param nextInputIds the next input ids
     * @param pastOutputIds the past output ids
     * @param pastKeyValues the past key/value cache
     * @param pastAttentionMask the past attention mask
     * @param lastProb the last probabilities
     */
    BeamBatchTensorList(
            NDArray nextInputIds,
            NDArray pastOutputIds,
            NDList pastKeyValues,
            NDArray pastAttentionMask,
            NDArray lastProb) {
        this.nextInputIds = nextInputIds;
        this.lastProbs = lastProb;
        this.pastAttentionMask = pastAttentionMask;
        this.pastOutputIds = pastOutputIds;
        this.pastKeyValues = pastKeyValues;
    }

    /**
     * Returns a new, empty {@code BeamBatchTensorList}.
     *
     * <p>Both arguments are currently ignored.
     *
     * @param inputList unused
     * @param seqDimOrder unused
     * @return a new instance with no fields set
     */
    @Override
    public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
        return new BeamBatchTensorList();
    }

    /**
     * Returns a new, empty {@link NDList}.
     *
     * <p>None of the tensors held by this object are currently added to the result.
     *
     * @return an empty list
     */
    @Override
    public NDList getList() {
        return new NDList();
    }

    /**
     * Gets the value of the nextInputIds.
     *
     * @return the value of nextInputIds
     */
    public NDArray getNextInputIds() {
        return this.nextInputIds;
    }

    /**
     * Sets the value of the nextInputIds.
     *
     * @param nextInputIds the new value of nextInputIds
     */
    public void setNextInputIds(NDArray nextInputIds) {
        this.nextInputIds = nextInputIds;
    }

    /**
     * Gets the value of the lastProbs.
     *
     * @return the value of lastProbs
     */
    public NDArray getLastProbs() {
        return this.lastProbs;
    }

    /**
     * Sets the value of the lastProbs.
     *
     * @param lastProbs the new value of lastProbs
     */
    public void setLastProbs(NDArray lastProbs) {
        this.lastProbs = lastProbs;
    }

    /**
     * Gets the value of the pastAttentionMask.
     *
     * @return the value of pastAttentionMask
     */
    @Override
    public NDArray getPastAttentionMask() {
        return this.pastAttentionMask;
    }

    /**
     * Sets the value of the pastAttentionMask.
     *
     * @param pastAttentionMask the new value of pastAttentionMask
     */
    @Override
    public void setPastAttentionMask(NDArray pastAttentionMask) {
        this.pastAttentionMask = pastAttentionMask;
    }

    /**
     * Gets the value of the pastOutputIds.
     *
     * @return the value of pastOutputIds
     */
    @Override
    public NDArray getPastOutputIds() {
        return this.pastOutputIds;
    }

    /**
     * Sets the value of the pastOutputIds.
     *
     * @param pastOutputIds the new value of pastOutputIds
     */
    @Override
    public void setPastOutputIds(NDArray pastOutputIds) {
        this.pastOutputIds = pastOutputIds;
    }

    /**
     * Gets the value of the pastKeyValues.
     *
     * @return the value of pastKeyValues
     */
    @Override
    public NDList getPastKeyValues() {
        return this.pastKeyValues;
    }

    /**
     * Sets the value of the pastKeyValues.
     *
     * @param pastKeyValues the new value of pastKeyValues
     */
    @Override
    public void setPastKeyValues(NDList pastKeyValues) {
        this.pastKeyValues = pastKeyValues;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** CausalLMOutput is used to contain multiple outputs of a language model. */
public class CausalLMOutput {

    // [batch, seq, feature]
    // The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size =
    // |inputIds|
    private NDArray logits;

    // [batch, seq, dim] * (layers+1) -> take -1
    // The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds|
    // Never reassigned after construction; null when built without hidden states.
    private final NDArray hiddenStates;

    // (k, v) * numLayer,
    // kv: [batch, heads, seq_past, feature]
    // The cache of past sequence. seq-dim-size == |seq_past| + |inputIds|
    // Never reassigned after construction.
    private final NDList pastKeyValuesList;

    /**
     * Constructs a {@code CausalLMOutput} without hidden states.
     *
     * <p>{@link #getHiddenState()} returns {@code null} for instances created this way.
     *
     * @param logits the logits
     * @param pastKeyValues the past key/value cache
     */
    public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
        this(logits, null, pastKeyValues);
    }

    /**
     * Constructs a {@code CausalLMOutput} with hidden states.
     *
     * @param logits the logits
     * @param hiddenState the hidden states
     * @param pastKeyValueList the past key/value cache
     */
    public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) {
        this.logits = logits;
        this.hiddenStates = hiddenState;
        this.pastKeyValuesList = pastKeyValueList;
    }

    /**
     * Gets the value of the logits.
     *
     * @return the value of logits
     */
    public NDArray getLogits() {
        return logits;
    }

    /**
     * Sets the value of the logits.
     *
     * @param logits the new value of logits
     */
    public void setLogits(NDArray logits) {
        this.logits = logits;
    }

    /**
     * Gets the value of the hiddenStates.
     *
     * @return the value of hiddenStates, or {@code null} if none was supplied at construction
     */
    public NDArray getHiddenState() {
        return hiddenStates;
    }

    /**
     * Gets the value of the pastKeyValuesList.
     *
     * @return the value of pastKeyValuesList
     */
    public NDList getPastKeyValuesList() {
        return pastKeyValuesList;
    }
}
Loading