Skip to content

Commit

Permalink
Batch_scheduler greedy_and_beam contrastiveSearch LLMDecoder front_e…
Browse files Browse the repository at this point in the history
…nd_translator GPT2PT.merge
  • Loading branch information
KexinFeng committed Jun 8, 2023
1 parent 59d57f8 commit cf4baa9
Show file tree
Hide file tree
Showing 42 changed files with 3,169 additions and 29 deletions.
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public static Application of(String path) {
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/text_generation":
return NLP.TEXT_GENERATION;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -261,6 +263,8 @@ public interface NLP {
*/
Application WORD_EMBEDDING = new Application("nlp/word_embedding");

Application TEXT_GENERATION = new Application("nlp/text_generation");

/**
* An application that translates text from one language to another.
*
Expand Down
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.modality.nlp.generate.GPTConfig;
import ai.djl.modality.nlp.generate.LMBlock;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
Expand Down Expand Up @@ -302,6 +305,10 @@ public SymbolBlock newSymbolBlock(NDManager manager) {
*/
public abstract NDManager newBaseManager(Device device);

/**
 * Creates a new {@link LMBlock} for this engine.
 *
 * <p>This default implementation does not build anything: it always throws {@link
 * UnsupportedOperationException}. NOTE(review): presumably engine subclasses that support
 * language-model blocks override this method -- confirm against the concrete engines.
 *
 * @param languageModel the language model identifier (unused by this default implementation)
 * @param gptConfig the GPT configuration (unused by this default implementation)
 * @param blocks the underlying blocks (unused by this default implementation)
 * @return never returns normally in this default implementation
 * @throws UnsupportedOperationException always, in this default implementation
 */
public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) {
    final String reason = "Not supported.";
    throw new UnsupportedOperationException(reason);
}

/**
* Returns a new instance of {@link GradientCollector}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * Represents a search state whose {@link NDArray}s are updated in each iteration of the
 * autoregressive loop.
 *
 * <p>It is a struct consisting of NDArrays whose first dimension is batch and which also contain
 * a sequence dimension; the position of that dimension in each tensor's shape is specified by
 * {@code seqDimOrder}. The SeqBatcher batch operations operate on these two dimensions.
 */
public abstract class BatchTensorList {

    /** [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. */
    private NDArray pastOutputIds;

    /**
     * [batch, seq_past]. The cache of past attentionMask. seq-dim-size == |past_seq| +
     * |inputIds|. Will grow.
     */
    private NDArray pastAttentionMask;

    /**
     * (k, v) * numLayer, with kv: [batch, heads, seq_past, kvfeature]. The cache of past
     * sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
     */
    private NDList pastKeyValues;

    /** Sequence dimension order among all dimensions for each element in the batch list. */
    private long[] seqDimOrder;

    /** Creates an empty instance; every field is left unset. */
    BatchTensorList() {}

    /**
     * Creates an instance from a flat list.
     *
     * <p>Element 0 is taken as the past output ids, element 1 as the past attention mask, and
     * the remainder (from index 2 on) as the past key/value cache.
     *
     * @param list the flat list of tensors
     * @param seqDimOrder the sequence dimension order
     */
    BatchTensorList(NDList list, long[] seqDimOrder) {
        this(list.get(0), list.get(1), list.subNDList(2), seqDimOrder);
    }

    /**
     * Creates an instance from its individual components.
     *
     * @param pastOutputIds the past output ids
     * @param pastAttentionMask the past attention mask
     * @param pastKeyValues the past key/value cache
     * @param seqDimOrder the sequence dimension order
     */
    BatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = pastKeyValues;
        this.seqDimOrder = seqDimOrder;
    }

    /**
     * Builds a batch tensor list from a flat list and a sequence dimension order.
     *
     * @param inputList the flat list of tensors
     * @param seqDimOrder the sequence dimension order
     * @return the constructed batch tensor list
     */
    public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

    /**
     * Returns the content of this object as a flat list.
     *
     * <p>The pastOutputIds has to be the first element in the returned list.
     *
     * @return the flat list of tensors
     */
    public abstract NDList getList();

    /**
     * Gets the value of the pastOutputIds.
     *
     * @return the value of pastOutputIds
     */
    public NDArray getPastOutputIds() {
        return pastOutputIds;
    }

    /**
     * Sets the value of the pastOutputIds.
     *
     * @param pastOutputIds the new value of pastOutputIds
     */
    public void setPastOutputIds(NDArray pastOutputIds) {
        this.pastOutputIds = pastOutputIds;
    }

    /**
     * Gets the value of the pastAttentionMask.
     *
     * @return the value of pastAttentionMask
     */
    public NDArray getPastAttentionMask() {
        return pastAttentionMask;
    }

    /**
     * Sets the value of the pastAttentionMask.
     *
     * @param pastAttentionMask the new value of pastAttentionMask
     */
    public void setPastAttentionMask(NDArray pastAttentionMask) {
        this.pastAttentionMask = pastAttentionMask;
    }

    /**
     * Gets the value of the pastKeyValues.
     *
     * @return the value of pastKeyValues
     */
    public NDList getPastKeyValues() {
        return pastKeyValues;
    }

    /**
     * Sets the value of the pastKeyValues.
     *
     * @param pastKeyValues the new value of pastKeyValues
     */
    public void setPastKeyValues(NDList pastKeyValues) {
        this.pastKeyValues = pastKeyValues;
    }

    /**
     * Gets the sequence dimension order.
     *
     * @return the value of seqDimOrder
     */
    public long[] getSeqDimOrder() {
        return seqDimOrder;
    }

    /**
     * Sets the sequence dimension order.
     *
     * @param seqDimOrder the new value of seqDimOrder
     */
    public void setSeqDimOrder(long[] seqDimOrder) {
        this.seqDimOrder = seqDimOrder;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * A {@link BatchTensorList} whose tensors carry an additional beam dimension.
 *
 * <p>The past output ids, past attention mask and past key/value cache are stored in the
 * superclass and accessed through its inherited getters and setters; this class only adds the
 * beam-specific state. Expected layouts for the inherited state in this variant:
 *
 * <ul>
 *   <li>pastAttentionMask: [batch, beam, seq_past + new_seq]. The cache of past attentionMask.
 *       seq-dim-size == |past_seq| + |inputIds|. Will grow.
 *   <li>pastOutputIds: [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will
 *       grow.
 *   <li>pastKeyValues: (k, v) * numLayer, with kv: [batch, beam, heads, seq_past, kvfeature].
 *       The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
 * </ul>
 *
 * <p>pastOutputIds and pastKeyValues are one time step behind nextInputIds, lastProbs and
 * pastAttentionMask, i.e. they contain all the past sequence but exclude the time step that
 * corresponds to the next input.
 */
class BeamBatchTensorList extends BatchTensorList {

    /** [batch, beam, seq=1]. */
    private NDArray nextInputIds;

    /** [batch, beam]. */
    private NDArray lastProbs;

    /** Creates an empty instance; every field is left unset. */
    BeamBatchTensorList() {}

    /**
     * Creates an instance from its individual components.
     *
     * @param nextInputIds the next input ids
     * @param pastOutputIds the past output ids
     * @param pastKeyValues the past key/value cache
     * @param pastAttentionMask the past attention mask
     * @param lastProb the last probabilities
     */
    BeamBatchTensorList(
            NDArray nextInputIds,
            NDArray pastOutputIds,
            NDList pastKeyValues,
            NDArray pastAttentionMask,
            NDArray lastProb) {
        // Store the shared state in the superclass rather than in shadowing duplicate fields.
        // seqDimOrder stays null, exactly as it was when the implicit no-arg super() was used.
        super(pastOutputIds, pastAttentionMask, pastKeyValues, null);
        this.nextInputIds = nextInputIds;
        this.lastProbs = lastProb;
    }

    /**
     * Returns a new, empty {@code BeamBatchTensorList}.
     *
     * <p>Both arguments are currently ignored. NOTE(review): this looks like a placeholder
     * implementation -- confirm before relying on it.
     *
     * @param inputList ignored
     * @param seqDimOrder ignored
     * @return a new empty instance
     */
    @Override
    public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
        return new BeamBatchTensorList();
    }

    /**
     * Returns a new, empty {@link NDList}.
     *
     * <p>None of this object's state is included. NOTE(review): this looks like a placeholder
     * implementation -- confirm before relying on it.
     *
     * @return a new empty list
     */
    @Override
    public NDList getList() {
        return new NDList();
    }

    /**
     * Gets the value of the nextInputIds.
     *
     * @return the value of nextInputIds
     */
    public NDArray getNextInputIds() {
        return nextInputIds;
    }

    /**
     * Sets the value of the nextInputIds.
     *
     * @param nextInputIds the new value of nextInputIds
     */
    public void setNextInputIds(NDArray nextInputIds) {
        this.nextInputIds = nextInputIds;
    }

    /**
     * Gets the value of the lastProbs.
     *
     * @return the value of lastProbs
     */
    public NDArray getLastProbs() {
        return lastProbs;
    }

    /**
     * Sets the value of the lastProbs.
     *
     * @param lastProbs the new value of lastProbs
     */
    public void setLastProbs(NDArray lastProbs) {
        this.lastProbs = lastProbs;
    }
}
64 changes: 64 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** {@code CausalLMOutput} is a container for the multiple outputs of a language model. */
public class CausalLMOutput {

    /**
     * [batch, seq, feature]. The prob. conditional on a sequence that ends at an element in
     * seq-dim. seq-dim-size = |inputIds|.
     */
    private NDArray logits;

    /**
     * [batch, seq, dim] * (layers+1) -> take -1. The vec. rep. of a sequence that ends at an
     * element in seq-dim. seq-dim-size = |inputIds|.
     */
    private NDArray hiddenStates;

    /**
     * (k, v) * numLayer, with kv: [batch, heads, seq_past, feature]. The cache of past
     * sequence. seq-dim-size == |seq_past| + |inputIds|.
     */
    private NDList pastKeyValuesList;

    /**
     * Creates an output that holds logits and the key/value cache; the hidden states are left
     * unset ({@code null}).
     *
     * @param logits the logits
     * @param pastKeyValues the past key/value cache
     */
    public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
        this(logits, null, pastKeyValues);
    }

    /**
     * Creates an output that holds logits, hidden states and the key/value cache.
     *
     * @param logits the logits
     * @param hiddenState the hidden states
     * @param pastKeyValueList the past key/value cache
     */
    public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) {
        this.logits = logits;
        this.hiddenStates = hiddenState;
        this.pastKeyValuesList = pastKeyValueList;
    }

    /**
     * Gets the value of the logits.
     *
     * @return the value of logits
     */
    public NDArray getLogits() {
        return logits;
    }

    /**
     * Sets the value of the logits.
     *
     * @param logits the new value of logits
     */
    public void setLogits(NDArray logits) {
        this.logits = logits;
    }

    /**
     * Gets the value of the hiddenStates.
     *
     * @return the value of hiddenStates, or {@code null} if this output was created without
     *     hidden states
     */
    public NDArray getHiddenState() {
        return hiddenStates;
    }

    /**
     * Gets the value of the pastKeyValuesList.
     *
     * @return the value of pastKeyValuesList
     */
    public NDList getPastKeyValuesList() {
        return pastKeyValuesList;
    }
}
Loading

0 comments on commit cf4baa9

Please sign in to comment.