Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 21, 2023
1 parent 9fa5bb7 commit b3709c9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,18 @@ public SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig
/**
* Initialize the iteration and SeqBatcher
*
* @return SeqBatcher. Stores the search state and operate on the BatchTensorList.
* @param inputIds the input token ids.
* @param batchUids the request uid identifying a sequence
* @return the SeqBatcher, which stores the search state and operates on the BatchTensorList
*/
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
throws TranslateException;

/**
* Go forward for a given number of iterations.
*
* @return boolean. Indicate whether the Batch is empty.
* @param count the number of forward calls to make
* @return a boolean indicating whether the batch is empty
*/
public boolean incrementForward(int count) throws TranslateException {
int i = 0;
Expand All @@ -82,6 +85,9 @@ public boolean incrementForward(int count) throws TranslateException {
return false;
}

/**
 * Performs a single inference call within one forward iteration.
 *
 * <p>Declared abstract, so the concrete behaviour is supplied by subclasses.
 *
 * @return the NDArray produced by the inference call (NOTE(review): presumably the output
 *     token ids of this iteration; confirm against the implementations)
 * @throws TranslateException declared by the signature; presumably raised when the underlying
 *     prediction fails (confirm against the implementations)
 */
abstract NDArray inferenceCall() throws TranslateException;

/** Add new batch. */
Expand All @@ -94,13 +100,24 @@ public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateExce
}
}

/** Collect finished results. */
/**
 * Hands over the results accumulated so far and replaces the internal store with a fresh,
 * empty map.
 *
 * @return the previously accumulated outputs, as a map from request uid to output token ids
 */
public Map<Long, NDArray> collectResults() {
    // Keep a reference to the current map before swapping in a new empty one.
    Map<Long, NDArray> collected = this.results;
    this.results = new ConcurrentHashMap<>();
    return collected;
}

/**
* Compute the offSets by linear search from the left
*
* @param inputIds input token ids
* @param config search configuration
* @return the offsets NDArray
*/
static NDArray computeOffSets(NDArray inputIds, SearchConfig config) {
int numBatch = Math.toIntExact(inputIds.getShape().get(0));
int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));
Expand All @@ -123,6 +140,13 @@ static NDArray computeOffSets(NDArray inputIds, SearchConfig config) {
return manager.create(offSetsArray).reshape(-1, 1);
}

/**
* Compute the attention mask by linear search from the left
*
* @param inputIds input token ids
* @param config search configuration
* @return the attention mask NDArray
*/
static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) {
int numBatch = Math.toIntExact(inputIds.getShape().get(0));
int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));
Expand Down Expand Up @@ -150,6 +174,15 @@ static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) {
return attentionMask;
}

/**
* Compute the position ids by linear search from the left
*
* @param inputIds input token ids
* @param offSets the offset
* @param pastSeqLength past sequence length
* @param repeat the number of repeats used in interleave-repeating the position_ids to multiple rows
* @return the position ids NDArray
*/
static NDArray computePositionIds(
NDArray inputIds, NDArray offSets, long pastSeqLength, int repeat) {
NDManager manager = inputIds.getManager();
Expand Down
15 changes: 14 additions & 1 deletion api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta
/**
* Check which batch needs to exit, according to certain criteria like EOS or maxLength. It is an
* iteration over the batch and is thus also considered a batch operation.
*
* @param outputIds output token ids in an incremental forward call
* @param maxLength max total sequence length
* @param eosTokenId end of sentence token id
*/
public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {
long[] outputIdsArray = outputIds.toLongArray();
Expand All @@ -162,7 +166,10 @@ public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {
}
}

/** Collect the finished sequences and trim the left padding. */
/** Collect the finished sequences and trim the left padding.
*
* @return a map that stores request id to output token ids
*/
public Map<Long, NDArray> collectAndTrim() {
if (exitIndexEndPosition.isEmpty()) {
return new ConcurrentHashMap<>();
Expand Down Expand Up @@ -261,6 +268,12 @@ public Map<Long, NDArray> collectAndTrim() {
}
}


/**
 * Reports whether at least one exit entry has been recorded in {@code exitIndexEndPosition}.
 *
 * <p>NOTE(review): the entries are presumably the sequences flagged as finished by
 * {@code exitCriteria}; confirm against that method.
 *
 * @return {@code true} if {@code exitIndexEndPosition} is non-empty, {@code false} otherwise
 */
public boolean sequenceComplete() {
    boolean nothingRecorded = exitIndexEndPosition.isEmpty();
    return !nothingRecorded;
}
Expand Down

0 comments on commit b3709c9

Please sign in to comment.