Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch the sequences with ContrastiveSeqBatchScheduler #2572

Closed
wants to merge 6 commits into from

Conversation

KexinFeng
Copy link
Contributor

@KexinFeng KexinFeng commented Apr 27, 2023

This PR succeeds PR #2547 #2509 #2557, which contains the benchmark outputs of the searching results.

The traced models can be downloaded here: https://d2l-java-resources.s3.amazonaws.com/tmp_model_unittest/model/nlp/text_generation/ai/djl/pytorch/gpt2/metadata.json and https://d2l-java-resources.s3.amazonaws.com/tmp_model_unittest/gpt2_decoder_model_merged.onnx.gz

Design

Overview

These PRs consists of the following three parts: LMAdapter, LMSearch and SeqBatchScheduler. There are three testing code corresponding to them: TestLMAdapter.java, TestLMSearch.java, TestSeqBatchScheduler.java. The TestSeqBatchScheduler.java utilizes all previous features.

New classes

/* 
This is a scheduler, serving as an API to the consumer of the systme, allowing for three major actions: initForward, addBatch, fastForward, collectResults. An optimal 
control sequence should be solved, after taking into the time consumption of each action, the batch size and 
sequence length of queueing requests.  Such optimal control solver needs additional effort. Primitive policy is 
setting several thresholds.
*/
SeqBatchScheduler {

    LMAdapter lmAdapter;

    SeqBatcher seqBatcher;

    public SeqBatchScheduler(LMAdapter lmAdapter) {
        this.lmAdapter = lmAdapter;
    }

    public abstract SeqBatcher initForward(
            NDArray inputIds, NDArray batchUids, SearchConfig config);

    public void incrementForward(int count) {}

    public void addRequest(NDArray inputIds, NDArray batchUids, SearchConfig config) {}

    public ConcurrentHashMap<Long, NDArray> collectResults() {}
}
// This stores the search state (BatchList), the control variables (eg seqLength, offSets, etc), and batch operations on BatchList.
public class SeqBatcher {

    NDManager manager;

    long batchSize;

    long seqLength;

    NDArray batchUid;

    NDArray offSets;

    BatchList data;

    ConcurrentHashMap<Long, NDArray> finishedSequences;

    List<Pair<Long, Long>> exitIndexEndPosition;

    SeqBatcher(BatchList data, NDArray batchUid, NDArray offSets, NDManager manager) {}

    /** Add new batch. Modify the batch dimension and the left padding. */
    public void addBatch(SeqBatcher seqBatcherNew) {   }

    /** Merge two batchers together. Modify the batch dimension and the left padding. */
    private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) {}

    /** Check which batch needs to exit, according certain criteria like EOS or maxLength. */
    public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {}

    /** Collect the finished sequences and trim the left padding. */
    public void collectAndTrim() { }
}
// BatchList is a struct consisting of NDArrays, whose first dimension is batch, and also contains 
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder)
// It represents a search state, and the NDArrays inside are updated in each iteration of the autoregressive loop.
// The batch operators in SeqBatcher will be applied on them.
public abstract class BatchList {

    // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    public NDArray pastOutputIds;

    // [batch, seq_past]
    // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    public NDArray pastAttentionMask;

    // (k, v) * numLayer,
    // kv: [batch, heads, seq_past, kvfeature]
    // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
    public NDList pastKeyValues;

    // Sequence dimension order among all dimensions for each element in the batch list.
    public long[] seqDimOrder;

    BatchList() {}

    BatchList(NDList list, long[] seqDimOrder) {}

    BatchList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDList pastKeyValues,
            long[] seqDimOrder) {}

    public abstract BatchList fromList(NDList inputList, long[] seqDimOrder);

    // The pastOutputIds has to be the first in the output list
    public abstract NDList getList();

    public long[] getSeqDimOrder() {
        return seqDimOrder;
    }
   } 
}

This is partly adapted from the searching algorithms in TestLMSearch.java. Currently only the contrastiveSearch has been adapted. The greedySearch and beamSearch will be similarly implemented. On top of that, batching operations are implemented. The system is designed as following, which contains the batching operations like addRequest() and collectAdnTrim()

@KexinFeng KexinFeng marked this pull request as ready for review April 28, 2023 04:49
@KexinFeng KexinFeng requested review from zachgk, frankfliu and a team as code owners April 28, 2023 04:49
@KexinFeng KexinFeng marked this pull request as draft April 28, 2023 04:49
@KexinFeng KexinFeng changed the title Batch the sequence Batch the sequences with ContrastiveSeqBatchScheduler Apr 28, 2023
@KexinFeng KexinFeng marked this pull request as ready for review April 29, 2023 02:57
@KexinFeng KexinFeng force-pushed the batching branch 7 times, most recently from d021197 to e84a53d Compare May 13, 2023 02:03
zachgk
zachgk previously requested changes May 16, 2023
api/src/main/java/ai/djl/engine/Engine.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/nn/AbstractBaseBlock.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/translate/LMBlock.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/translate/SeqBatchScheduler.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/translate/SeqBatchScheduler.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/translate/LMSearch.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/translate/LMSearch.java Outdated Show resolved Hide resolved
@KexinFeng
Copy link
Contributor Author

Features have been merged in #2637

@KexinFeng KexinFeng closed this Jun 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants