Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 23, 2023
1 parent 9fa5bb7 commit a071e0b
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
// of the
// autoregressive loop.
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder).
// The SeqBatcher batch operations will operate on these two dimensions.
/**
* BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
* of the autoregressive loop It is a struct consisting of NDArrays, whose first dimension is batch,
* and also contains sequence dimension (whose position in tensor's shape is specified by seqDimOrder).
* The SeqBatcher batch operations will operate on these two dimensions.
*/
public abstract class BatchTensorList {
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastOutputIds;
Expand All @@ -39,13 +38,29 @@ public abstract class BatchTensorList {

BatchTensorList() {}

/**
* Constructs a BatchTensorList.
*
* @param list the NDList that contains the serialized version of the batch tensors
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
*/
BatchTensorList(NDList list, long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
pastOutputIds = list.get(0);
pastAttentionMask = list.get(1);
pastKeyValues = list.subNDList(2);
}

/**
* Constructs a BatchTensorList.
*
* @param pastOutputIds past output token ids
* @param pastAttentionMask past attention mask
* @param pastKeyValues past kv cache
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
*/
BatchTensorList(
NDArray pastOutputIds,
NDArray pastAttentionMask,
Expand All @@ -57,11 +72,32 @@ public abstract class BatchTensorList {
this.seqDimOrder = seqDimOrder;
}

/**
* Construct a BatchTensorList from the serialized version of the batch tensors. The
* pastOutputIds has to be the first in the output list.
*
* @param inputList the serialized version of the batch tensors
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
* @return BatchTensorList
*/
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

// The pastOutputIds has to be the first in the output list
/**
* Gets the serialized version of the BatchTensorList. The pastOutputIds has to be the first in
* the output list.
*
* @return the NDList that contains the serialized BatchTensorList
*/
public abstract NDList getList();

/**
* Gets the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape.
*
* @return the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape
*/
public long[] getSeqDimOrder() {
return seqDimOrder;
}
Expand All @@ -75,6 +111,11 @@ public NDArray getPastOutputIds() {
return pastOutputIds;
}

/**
* Sets the past output token ids.
*
* @param pastOutputIds the past output token ids
*/
public void setPastOutputIds(NDArray pastOutputIds) {
this.pastOutputIds = pastOutputIds;
}
Expand All @@ -88,6 +129,11 @@ public NDArray getPastAttentionMask() {
return pastAttentionMask;
}

/**
* Sets the attention mask.
*
* @param pastAttentionMask the attention mask
*/
public void setPastAttentionMask(NDArray pastAttentionMask) {
this.pastAttentionMask = pastAttentionMask;
}
Expand All @@ -101,10 +147,22 @@ public NDList getPastKeyValues() {
return pastKeyValues;
}

/**
* Sets the kv cache.
*
* @param pastKeyValues the kv cache
*/
public void setPastKeyValues(NDList pastKeyValues) {
this.pastKeyValues = pastKeyValues;
}

/**
* Sets the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape.
*
* @param seqDimOrder the sequence dimension order which specifies where the sequence dimension
* is in a tensor's shape
*/
public void setSeqDimOrder(long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
}
Expand Down
18 changes: 18 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,24 @@ public class CausalLMOutput {
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds|
private NDList pastKeyValuesList;

/**
* Construct the CausalLMOutput.
*
* @param logits the logits NDArray
* @param pastKeyValues the key-value cache
*/
public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
this.logits = logits;
this.pastKeyValuesList = pastKeyValues;
}

/**
* Construct the CausalLMOutput.
*
* @param logits the logits NDArray
* @param hiddenState the first layer hiddenStates used as word embedding
* @param pastKeyValueList the key-value cache
*/
public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) {
this.logits = logits;
this.pastKeyValuesList = pastKeyValueList;
Expand All @@ -52,6 +65,11 @@ public NDArray getLogits() {
return logits;
}

/**
* Sets the value of the logits.
*
* @param logits value of logits NDArray
*/
public void setLogits(NDArray logits) {
this.logits = logits;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,24 @@
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* {@code ContrastiveSeqBatchScheduler} is a class which implements the contrastive search algorithm
* used in SeqBatchScheduler.
*/
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler {

/**
* Construct a ContrastiveSeqBatchScheduler.
*
* @param lmBlock the predictor containing language model
* @param config the autoregressive search configuration
*/
public ContrastiveSeqBatchScheduler(
Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
super(lmBlock, config);
}

/** {@inheritDoc} */
@Override
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException {
try (NDScope scope = new NDScope()) {
Expand Down Expand Up @@ -72,6 +83,7 @@ public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws Transl
}
}

/** {@inheritDoc} */
@Override
public NDArray inferenceCall() throws TranslateException {
NDArray outputIds;
Expand Down
36 changes: 35 additions & 1 deletion api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
*/
package ai.djl.modality.nlp.generate;

/**
* {@code SearchConfig} is a class whose fields are parameters used for autoregressive search / text
* generation.
*/
public class SearchConfig {

private int k;
Expand All @@ -30,7 +34,7 @@ public SearchConfig() {
this.maxSeqLength = 30;
this.eosTokenId = 50256;
this.padTokenId = 50256;
this.suffixPadding = true;
this.suffixPadding = false;
}

/**
Expand All @@ -42,6 +46,11 @@ public int getK() {
return k;
}

/**
* Sets the value for the topk choice.
*
* @param k the value for topk choice
*/
public void setK(int k) {
this.k = k;
}
Expand All @@ -55,6 +64,11 @@ public float getAlpha() {
return alpha;
}

/**
* Sets the value of alpha the penalty for repetition.
*
* @param alpha the value of the penalty for repetition
*/
public void setAlpha(float alpha) {
this.alpha = alpha;
}
Expand All @@ -68,6 +82,11 @@ public int getBeam() {
return beam;
}

/**
* Sets the value of beam size.
*
* @param beam the value of beam size
*/
public void setBeam(int beam) {
this.beam = beam;
}
Expand All @@ -81,6 +100,11 @@ public int getMaxSeqLength() {
return maxSeqLength;
}

/**
* Sets the value of max sequence length.
*
* @param maxSeqLength the value max sequence length
*/
public void setMaxSeqLength(int maxSeqLength) {
this.maxSeqLength = maxSeqLength;
}
Expand All @@ -94,6 +118,11 @@ public long getPadTokenId() {
return padTokenId;
}

/**
* Sets the value of padTokenId.
*
* @param padTokenId the token id for padding
*/
public void setPadTokenId(long padTokenId) {
this.padTokenId = padTokenId;
}
Expand All @@ -116,6 +145,11 @@ public boolean isSuffixPadding() {
return suffixPadding;
}

/**
* Sets the value of suffixPadding or rightPadding.
*
* @param suffixPadding whether the padding is from right
*/
public void setSuffixPadding(boolean suffixPadding) {
this.suffixPadding = suffixPadding;
}
Expand Down
Loading

0 comments on commit a071e0b

Please sign in to comment.