Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 22, 2023
1 parent b3709c9 commit 39a6db1
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* {@code ContrastiveSeqBatchScheduler} is a class which implements the contrastive search algorithm
* used in SeqBatchScheduler.
*/
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler {

public ContrastiveSeqBatchScheduler(
Expand Down
36 changes: 35 additions & 1 deletion api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
*/
package ai.djl.modality.nlp.generate;

/**
* {@code SearchConfig} is a class whose fields are parameters used for autoregressive search / text
* generation.
*/
public class SearchConfig {

private int k;
Expand All @@ -30,7 +34,7 @@ public SearchConfig() {
this.maxSeqLength = 30;
this.eosTokenId = 50256;
this.padTokenId = 50256;
this.suffixPadding = true;
this.suffixPadding = false;
}

/**
Expand All @@ -42,6 +46,11 @@ public int getK() {
return k;
}

/**
* Sets the value of k used for the top-k choice.
*
* @param k the value of k for the top-k choice
*/
public void setK(int k) {
this.k = k;
}
Expand All @@ -55,6 +64,11 @@ public float getAlpha() {
return alpha;
}

/**
* Sets the value of alpha, the penalty for repetition.
*
* @param alpha the value of the penalty for repetition
*/
public void setAlpha(float alpha) {
this.alpha = alpha;
}
Expand All @@ -68,6 +82,11 @@ public int getBeam() {
return beam;
}

/**
* Sets the beam size.
*
* @param beam the beam size
*/
public void setBeam(int beam) {
this.beam = beam;
}
Expand All @@ -81,6 +100,11 @@ public int getMaxSeqLength() {
return maxSeqLength;
}

/**
* Sets the max sequence length.
*
* @param maxSeqLength the value of the max sequence length
*/
public void setMaxSeqLength(int maxSeqLength) {
this.maxSeqLength = maxSeqLength;
}
Expand All @@ -94,6 +118,11 @@ public long getPadTokenId() {
return padTokenId;
}

/**
* Sets the id of the token used for padding.
*
* @param padTokenId the token id for padding
*/
public void setPadTokenId(long padTokenId) {
this.padTokenId = padTokenId;
}
Expand All @@ -116,6 +145,11 @@ public boolean isSuffixPadding() {
return suffixPadding;
}

/**
* Sets whether suffix padding (also called right padding) is used.
*
* @param suffixPadding whether the padding is applied from the right
*/
public void setSuffixPadding(boolean suffixPadding) {
this.suffixPadding = suffixPadding;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig
* @param inputIds the input token ids.
* @param batchUids the request uid identifying a sequence
* @return the SeqBatcher that stores the search state and operates on the BatchTensorList
* @throws TranslateException if forward fails
*/
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
throws TranslateException;
Expand All @@ -65,6 +66,7 @@ public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
*
* @param count the number of forward calls
* @return a boolean indicating whether the batch is empty
* @throws TranslateException if forward fails
*/
public boolean incrementForward(int count) throws TranslateException {
int i = 0;
Expand All @@ -85,12 +87,16 @@ public boolean incrementForward(int count) throws TranslateException {
return false;
}

/**
* An inference call in an iteration
*/
/** An inference call in an iteration. */
abstract NDArray inferenceCall() throws TranslateException;

/** Add new batch. */
/**
* Add new batch.
*
* @param inputIds the input token ids.
* @param batchUids the request uid identifying a sequence
* @throws TranslateException if forward fails
*/
public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException {
SeqBatcher seqBatcherNew = initForward(inputIds, batchUids);
if (seqBatcher == null) {
Expand Down Expand Up @@ -180,7 +186,8 @@ static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) {
* @param inputIds input token ids
* @param offSets the offset
* @param pastSeqLength past sequence length
* @param repeat the number of repeats used in interleave-repeating the position_ids to multiple rows
* @param repeat the number of repeats used in interleave-repeating the position_ids to multiple
* rows
* @return the position ids NDArray
*/
static NDArray computePositionIds(
Expand Down
22 changes: 17 additions & 5 deletions api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets,
// etc), and batch operations (merge, trim, exitCriteria, etc) on BatchTensorList.
/**
* {@code SeqBatcher} stores the search state (BatchTensorList), the control variables (e.g.
* seqLength, offSets, etc.), and batch operations (merge, trim, exitCriteria, etc.) on
* BatchTensorList.
*/
public class SeqBatcher {
NDManager manager;

Expand Down Expand Up @@ -56,11 +59,20 @@ public class SeqBatcher {
exitIndexEndPosition = new ConcurrentHashMap<>();
}

/**
* Gets the batch data which is stored as a {@code BatchTensorList}.
*
* @return the batch data stored as a {@code BatchTensorList}
*/
public BatchTensorList getData() {
return data;
}

/** Add new batch. Modify the batch dimension and the left padding. */
/**
* Add new batch. Modify the batch dimension and the left padding.
*
* @param seqBatcherNew the seqBatcher to add.
*/
public void addBatch(SeqBatcher seqBatcherNew) {
merge(this, seqBatcherNew, seqLength - seqBatcherNew.seqLength);
// manager and finishedSequences stay the same;
Expand Down Expand Up @@ -166,7 +178,8 @@ public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {
}
}

/** Collect the finished sequences and trim the left padding.
/**
* Collect the finished sequences and trim the left padding.
*
* @return a map that stores request id to output token ids
*/
Expand Down Expand Up @@ -268,7 +281,6 @@ public Map<Long, NDArray> collectAndTrim() {
}
}


/**
* Compute the position ids by linear search from the left
*
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* {@code TextGenerator} is an LMSearch (language model search) which contains multiple
* autoregressive search methods. It has a {@code Predictor<NDList, CausalLMOutput>}, which is
* called inside an autoregressive inference loop.
*/
public class TextGenerator {

private String searchName;
Expand All @@ -42,6 +47,13 @@ public TextGenerator(
this.config = searchConfig;
}

/**
* Greedy search.
*
* @param inputIds the input token ids.
* @return the output token ids stored as NDArray
* @throws TranslateException if forward fails
*/
@SuppressWarnings("try")
public NDArray greedySearch(NDArray inputIds) throws TranslateException {
NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config);
Expand Down

0 comments on commit 39a6db1

Please sign in to comment.