-
Notifications
You must be signed in to change notification settings - Fork 658
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Batch_scheduler greedy_and_beam contrastiveSearch LLMDecoder front_e…
…nd_translator GPT2PT.merge
- Loading branch information
Showing
42 changed files
with
3,169 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
99 changes: 99 additions & 0 deletions
99
api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration | ||
// of the | ||
// autoregressive loop. | ||
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains | ||
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder). | ||
// The SeqBatcher batch operations will operate on these two dimensions. | ||
public abstract class BatchTensorList { | ||
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastOutputIds; | ||
|
||
// [batch, seq_past] | ||
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastAttentionMask; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, heads, seq_past, kvfeature] | ||
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDList pastKeyValues; | ||
|
||
// Sequence dimension order among all dimensions for each element in the batch list. | ||
private long[] seqDimOrder; | ||
|
||
BatchTensorList() {} | ||
|
||
BatchTensorList(NDList list, long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
pastOutputIds = list.get(0); | ||
pastAttentionMask = list.get(1); | ||
pastKeyValues = list.subNDList(2); | ||
} | ||
|
||
BatchTensorList( | ||
NDArray pastOutputIds, | ||
NDArray pastAttentionMask, | ||
NDList pastKeyValues, | ||
long[] seqDimOrder) { | ||
this.pastKeyValues = pastKeyValues; | ||
this.pastOutputIds = pastOutputIds; | ||
this.pastAttentionMask = pastAttentionMask; | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
|
||
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); | ||
|
||
// The pastOutputIds has to be the first in the output list | ||
public abstract NDList getList(); | ||
|
||
public long[] getSeqDimOrder() { | ||
return seqDimOrder; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastOutputIds. | ||
* | ||
* @return the value of pastOutputIds | ||
*/ | ||
public NDArray getPastOutputIds() { | ||
return pastOutputIds; | ||
} | ||
|
||
public void setPastOutputIds(NDArray pastOutputIds) { | ||
this.pastOutputIds = pastOutputIds; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastAttentionMask. | ||
* | ||
* @return the value of pastAttentionMask | ||
*/ | ||
public NDArray getPastAttentionMask() { | ||
return pastAttentionMask; | ||
} | ||
|
||
public void setPastAttentionMask(NDArray pastAttentionMask) { | ||
this.pastAttentionMask = pastAttentionMask; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastKeyValues. | ||
* | ||
* @return the value of pastKeyValues | ||
*/ | ||
public NDList getPastKeyValues() { | ||
return pastKeyValues; | ||
} | ||
|
||
public void setPastKeyValues(NDList pastKeyValues) { | ||
this.pastKeyValues = pastKeyValues; | ||
} | ||
|
||
public void setSeqDimOrder(long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
} |
122 changes: 122 additions & 0 deletions
122
api/src/main/java/ai/djl/modality/nlp/generate/BeamBatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
class BeamBatchTensorList extends BatchTensorList { | ||
// [batch, beam, seq=1] | ||
private NDArray nextInputIds; | ||
|
||
// [batch, beam] | ||
private NDArray lastProbs; | ||
|
||
// [batch, beam, seq_past + new_seq] | ||
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastAttentionMask; | ||
|
||
/* Variables below are one time step behind the above state variables. Ie, they contain all the past sequence but excludes the time step that corresponds to the above input. */ | ||
|
||
// [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastOutputIds; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, beam, heads, seq_past, kvfeature] | ||
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDList pastKeyValues; | ||
|
||
BeamBatchTensorList() {} | ||
|
||
BeamBatchTensorList( | ||
NDArray nextInputIds, | ||
NDArray pastOutputIds, | ||
NDList pastKeyValues, | ||
NDArray pastAttentionMask, | ||
NDArray lastProb) { | ||
this.nextInputIds = nextInputIds; | ||
this.pastKeyValues = pastKeyValues; | ||
this.pastOutputIds = pastOutputIds; | ||
this.pastAttentionMask = pastAttentionMask; | ||
this.lastProbs = lastProb; | ||
} | ||
|
||
@Override | ||
public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) { | ||
return new BeamBatchTensorList(); | ||
} | ||
|
||
@Override | ||
public NDList getList() { | ||
return new NDList(); | ||
} | ||
|
||
/** | ||
* Gets the value of the nextInputIds. | ||
* | ||
* @return the value of nextInputIds | ||
*/ | ||
public NDArray getNextInputIds() { | ||
return nextInputIds; | ||
} | ||
|
||
public void setNextInputIds(NDArray nextInputIds) { | ||
this.nextInputIds = nextInputIds; | ||
} | ||
|
||
/** | ||
* Gets the value of the lastProbs. | ||
* | ||
* @return the value of lastProbs | ||
*/ | ||
public NDArray getLastProbs() { | ||
return lastProbs; | ||
} | ||
|
||
public void setLastProbs(NDArray lastProbs) { | ||
this.lastProbs = lastProbs; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastAttentionMask. | ||
* | ||
* @return the value of pastAttentionMask | ||
*/ | ||
@Override | ||
public NDArray getPastAttentionMask() { | ||
return pastAttentionMask; | ||
} | ||
|
||
@Override | ||
public void setPastAttentionMask(NDArray pastAttentionMask) { | ||
this.pastAttentionMask = pastAttentionMask; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastOutputIds. | ||
* | ||
* @return the value of pastOutputIds | ||
*/ | ||
@Override | ||
public NDArray getPastOutputIds() { | ||
return pastOutputIds; | ||
} | ||
|
||
@Override | ||
public void setPastOutputIds(NDArray pastOutputIds) { | ||
this.pastOutputIds = pastOutputIds; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastKeyValues. | ||
* | ||
* @return the value of pastKeyValues | ||
*/ | ||
@Override | ||
public NDList getPastKeyValues() { | ||
return pastKeyValues; | ||
} | ||
|
||
@Override | ||
public void setPastKeyValues(NDList pastKeyValues) { | ||
this.pastKeyValues = pastKeyValues; | ||
} | ||
} |
64 changes: 64 additions & 0 deletions
64
api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
/** CausalLMOuput is used to contain multiple output of a language model. */ | ||
public class CausalLMOutput { | ||
|
||
// [batch, seq, feature] | ||
// The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size = | ||
// |inputIds| | ||
private NDArray logits; | ||
|
||
// [batch, seq, dim] * (layers+1) -> take -1 | ||
// The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds| | ||
private NDArray hiddenStates; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, heads, seq_past, feature] | ||
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds| | ||
private NDList pastKeyValuesList; | ||
|
||
public CausalLMOutput(NDArray logits, NDList pastKeyValues) { | ||
this.logits = logits; | ||
this.pastKeyValuesList = pastKeyValues; | ||
} | ||
|
||
public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) { | ||
this.logits = logits; | ||
this.pastKeyValuesList = pastKeyValueList; | ||
this.hiddenStates = hiddenState; | ||
} | ||
|
||
/** | ||
* Gets the value of the logits. | ||
* | ||
* @return the value of logits | ||
*/ | ||
public NDArray getLogits() { | ||
return logits; | ||
} | ||
|
||
public void setLogits(NDArray logits) { | ||
this.logits = logits; | ||
} | ||
|
||
/** | ||
* Gets the value of the allHiddenStates. | ||
* | ||
* @return the value of allHiddenStates | ||
*/ | ||
public NDArray getHiddenState() { | ||
return hiddenStates; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastKeyValuesList. | ||
* | ||
* @return the value of pastKeyValuesList | ||
*/ | ||
public NDList getPastKeyValuesList() { | ||
return pastKeyValuesList; | ||
} | ||
} |
Oops, something went wrong.