[api] implements text-generation search algorithm #2637

Merged
merged 10 commits into from
Jun 27, 2023

Conversation

KexinFeng
Contributor

@KexinFeng KexinFeng commented Jun 6, 2023

This PR follows PRs #2547, #2509, #2557, and #2572, which contain the benchmark outputs of the search results.

This PR contains only the features of LMSearch.

djl/examples/src/main/java/ai/djl/examples/inference/GPTInference.java contains the front-end design.

Model conversion to TorchScript and ONNX

See the Model tracing section in the PR descriptions of #2547 and #2509.

Demonstration

PR #2723 provides several examples that demonstrate how to use language-model text generation.

@KexinFeng KexinFeng requested review from zachgk, frankfliu and a team as code owners June 6, 2023 19:26
@KexinFeng KexinFeng force-pushed the LMSearch branch 8 times, most recently from 4490def to 24d3a8e Compare June 9, 2023 19:59
@KexinFeng KexinFeng changed the title Lm search [api] LMSearch Jun 9, 2023
@KexinFeng KexinFeng force-pushed the LMSearch branch 2 times, most recently from 95ddf0e to 8d91ef7 Compare June 13, 2023 21:56
private boolean suffixPadding;

/** Constructs a new SearchConfig object with default values. */
public SearchConfig() {

@jawaff jawaff Jun 20, 2023

Are there any plans to support different configurations, since not all text-generation models are the same? I'm personally more interested in T5 than GPT2. T5 is a different beast: it has both an encoder and a decoder, in contrast to GPT2's decoder-only approach. T5 also supports over a hundred special tokens, including 100 "extra" tokens that can be used for a variety of things, such as fill masks and potentially representing special words or instructions in the generated output.

https://huggingface.co/transformers/v3.0.2/model_doc/t5.html#t5tokenizer

There probably need to be different configurations and generation classes for each family of models out there. If we hardcode everything to GPT2, there will be breaking changes in the future. I'd suggest supporting two different models at the start and coming up with a way to add support for others later.

Contributor Author

For SearchConfig, I'm thinking of just adding parameters to it; a single model would not necessarily use all of them. That should address the concern about different search configurations you mentioned, right?
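
A minimal sketch of that idea, assuming a single config class whose fields are read selectively by each search algorithm. This is not DJL's actual SearchConfig: only `suffixPadding` appears in the diff above, and the class names, the other field names, and every default value are hypothetical.

```java
// Hypothetical sketch, not DJL's actual SearchConfig: only suffixPadding
// appears in the diff above; all other names and defaults are assumptions.
final class SketchSearchConfig {
    private int maxSeqLength = 20;   // assumed: read by every search algorithm
    private int beam = 3;            // assumed: read only by beam search
    private float alpha = 0.6f;      // assumed: read only by contrastive search
    private boolean suffixPadding;   // field name taken from the diff
    private boolean encoderDecoder;  // assumed flag for T5-style models

    int getMaxSeqLength() { return maxSeqLength; }
    int getBeam() { return beam; }
    float getAlpha() { return alpha; }
    boolean isSuffixPadding() { return suffixPadding; }
    boolean isEncoderDecoder() { return encoderDecoder; }

    void setBeam(int beam) { this.beam = beam; }
    void setAlpha(float alpha) { this.alpha = alpha; }
    void setEncoderDecoder(boolean encoderDecoder) { this.encoderDecoder = encoderDecoder; }
}

final class SketchSearchDemo {
    /** Lists the parameters a given algorithm would read; the rest are ignored. */
    static String describe(String algorithm, SketchSearchConfig config) {
        String common = "maxSeqLength=" + config.getMaxSeqLength();
        switch (algorithm) {
            case "beam":
                return common + ", beam=" + config.getBeam();
            case "contrastive":
                return common + ", alpha=" + config.getAlpha();
            default: // greedy reads only the shared parameter
                return common;
        }
    }

    public static void main(String[] args) {
        SketchSearchConfig config = new SketchSearchConfig();
        config.setBeam(5);
        System.out.println(describe("greedy", config));
        System.out.println(describe("beam", config));
    }
}
```

The trade-off is that one flat config stays simple to pass around, but nothing stops a caller from setting a parameter that the chosen algorithm or model never reads.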

@jawaff

jawaff commented Jun 20, 2023

I'm just leaving my 2 cents since I'm interested in your work. It's great to see it get added to DJL. My only real fear is that there might need to be a lot of refactoring to get support for other models. Flan-T5 is one of the most powerful open source models (that supports commercial use) we have available and it has a variety of sizes available. I'd be most interested in seeing it be supported.

I'm not super familiar with GPT2 aside from it being a decoder-only model. There's a chance that T5 and GPT2 share some similarities in the decoder, but T5 runs an initial encoder pass over the inputs. The hidden state of the encoded inputs is then used in each pass of the decoder, alongside the ids that have been selected for generation so far.
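
The data flow described above can be sketched with a toy model. Everything here is hypothetical: the "encoder" and "decoder" are stand-in arithmetic, not T5 and not any DJL API. The sketch only shows that the encoder runs once while the decoder runs once per generated token, each time receiving the same encoder state plus the ids generated so far.

```java
import java.util.ArrayList;
import java.util.List;

// Toy illustration only: stand-in arithmetic, not a real T5 model or DJL API.
final class EncoderDecoderSketch {
    static final int VOCAB = 10;
    int encoderCalls;
    int decoderCalls;

    /** Toy encoder: the "hidden state" is just the sum of the input ids. */
    long encode(int[] inputIds) {
        encoderCalls++;
        long sum = 0;
        for (int id : inputIds) {
            sum += id;
        }
        return sum;
    }

    /** Toy decoder step: sees the encoder state and the ids generated so far. */
    int decodeStep(long hidden, List<Integer> generated) {
        decoderCalls++;
        return (int) ((hidden + generated.size()) % VOCAB);
    }

    List<Integer> generate(int[] inputIds, int maxNewTokens, int eosId) {
        long hidden = encode(inputIds); // encoder pass happens exactly once
        List<Integer> generated = new ArrayList<>();
        for (int i = 0; i < maxNewTokens; i++) {
            int next = decodeStep(hidden, generated); // hidden state reused every step
            generated.add(next);
            if (next == eosId) {
                break;
            }
        }
        return generated;
    }

    public static void main(String[] args) {
        EncoderDecoderSketch model = new EncoderDecoderSketch();
        System.out.println(model.generate(new int[] {1, 2, 3}, 10, 8));
        System.out.println(model.encoderCalls + " encoder call(s), "
                + model.decoderCalls + " decoder call(s)");
    }
}
```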

@jawaff

jawaff commented Jun 20, 2023

That's all I've got to add, good work. I just want to see this turn into a bigger feature beyond what you're working on.

@frankfliu
Contributor

That's all I've got to add, good work. I just want to see this turn into a bigger feature beyond what you're working on.

We are planning to add the T5 model. This is just a starting point for adding text-generation support.

@KexinFeng
Contributor Author

@jawaff Thanks for pointing out the encoder-decoder model T5 to us and reminding us of the possible refactoring.

I think that to implement an encoder-decoder model, the major change will be in the search algorithms, where we will need an if (encoderDecoder is true) block that computes the encoding. The rest of the code will basically be shared. This structure is also seen in Hugging Face Transformers.
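
A rough sketch of how such a branch might sit inside a shared greedy search. It is not the PR's code: the class and method names, `START_ID`, and the toy `logits` function are all assumptions made for illustration. Only the setup differs between the two model families; the decoding loop is common.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: names and the toy scoring function are assumptions,
// not the implementation merged in this PR.
final class SharedGreedySearchSketch {
    static final int VOCAB = 10;
    static final int START_ID = 0; // assumed decoder start token

    /** Toy encoder: sums the input ids. */
    static long encode(int[] inputIds) {
        long sum = 0;
        for (int id : inputIds) {
            sum += id;
        }
        return sum;
    }

    /** Toy logits that peak at (last token + encoding + 1) mod VOCAB. */
    static float[] logits(List<Integer> sequence, long encoding) {
        int last = sequence.get(sequence.size() - 1);
        int target = (int) ((last + encoding + 1) % VOCAB);
        float[] out = new float[VOCAB];
        for (int j = 0; j < VOCAB; j++) {
            out[j] = -Math.abs(j - target);
        }
        return out;
    }

    static int argmax(float[] values) {
        int best = 0;
        for (int j = 1; j < values.length; j++) {
            if (values[j] > values[best]) {
                best = j;
            }
        }
        return best;
    }

    static List<Integer> greedySearch(int[] inputIds, boolean encoderDecoder, int maxNewTokens) {
        long encoding = 0;
        List<Integer> sequence = new ArrayList<>();
        if (encoderDecoder) {
            // Extra step for T5-style models: encode once, decode from a start token.
            encoding = encode(inputIds);
            sequence.add(START_ID);
        } else {
            // Decoder-only models (GPT2 style) continue the prompt directly.
            for (int id : inputIds) {
                sequence.add(id);
            }
        }
        // Shared part: pick the highest-scoring token at every step.
        for (int i = 0; i < maxNewTokens; i++) {
            sequence.add(argmax(logits(sequence, encoding)));
        }
        return sequence;
    }

    public static void main(String[] args) {
        System.out.println(greedySearch(new int[] {1, 2, 3}, false, 2));
        System.out.println(greedySearch(new int[] {1, 2, 3}, true, 2));
    }
}
```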

This was referenced Jun 21, 2023
@codecov-commenter

codecov-commenter commented Jun 23, 2023

Codecov Report

Patch coverage: 54.83% and project coverage change: -0.03 ⚠️

Comparison is base (bb5073f) 72.08% compared to head (8242299) 72.06%.

Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2637      +/-   ##
============================================
- Coverage     72.08%   72.06%   -0.03%     
- Complexity     5126     7020    +1894     
============================================
  Files           473      698     +225     
  Lines         21970    31252    +9282     
  Branches       2351     3224     +873     
============================================
+ Hits          15838    22521    +6683     
- Misses         4925     7200    +2275     
- Partials       1207     1531     +324     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...djl/modality/cv/transform/RandomFlipLeftRight.java 25.00% <0.00%> (-25.00%) ⬇️
...djl/modality/cv/transform/RandomFlipTopBottom.java 25.00% <0.00%> (-25.00%) ⬇️
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <0.00%> (-5.24%) ⬇️
.../modality/cv/translator/ImageFeatureExtractor.java 0.00% <0.00%> (ø)
.../ai/djl/modality/cv/translator/YoloTranslator.java 27.77% <0.00%> (+18.95%) ⬆️
...ain/java/ai/djl/modality/cv/util/NDImageUtils.java 67.10% <0.00%> (+7.89%) ⬆️
api/src/main/java/ai/djl/modality/nlp/Decoder.java 63.63% <ø> (ø)
... and 227 more

... and 368 files with indirect coverage changes


@frankfliu frankfliu changed the title [api] LMSearch [api] implements text-generation search algorithm Jun 27, 2023
@KexinFeng KexinFeng merged commit 68c7a03 into deepjavalibrary:master Jun 27, 2023
5 checks passed
4 participants