-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Set up tgi environment values with the ones used to build the model (#…
…529) * Set up tgi environment values with the ones used to build the model Need this to workaround the model static params, for the docker entrypoint to adapt tgi environment accordingly to the specified model This will make usage of the image easier: default params (e.g not specifying anything) should be enough for most models Signed-off-by: Raphael Glon <[email protected]> * fixes Signed-off-by: Raphael Glon <[email protected]> * fixes Signed-off-by: Raphael Glon <[email protected]> * minor: logging Signed-off-by: Raphael Glon <[email protected]> * Integration tests for inf2 + tgi_env wrapper Signed-off-by: Raphael Glon <[email protected]> * Github ci worklow Signed-off-by: Raphael Glon <[email protected]> * To run on github ci we cannot share a volume but need to embed the model within an image built on the flight Signed-off-by: Raphael Glon <[email protected]> * More flexible on expected outputs Signed-off-by: Raphael Glon <[email protected]> * Be more flexible about compiler version Signed-off-by: Raphael Glon <[email protected]> * Refacto bump dev version, use a single workflow for TGI, simplify a bit the implicit env test Signed-off-by: Raphael Glon <[email protected]> * Misc fixes Signed-off-by: Raphael Glon <[email protected]> --------- Signed-off-by: Raphael Glon <[email protected]> Co-authored-by: Raphael Glon <[email protected]>
- Loading branch information
Showing
8 changed files
with
387 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
text-generation-inference/integration-tests/test_implicit_env.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
|
||
import pytest | ||
from text_generation.errors import ValidationError | ||
|
||
|
||
# These tests will often break as it relies on many factors like the optimum version, the neuronx-cc version, | ||
# and on what is synced in the cache for these specific versions... | ||
|
||
MODELS = ["openai-community/gpt2", "aws-neuron/gpt2-neuronx-bs4-seqlen1024"] | ||
|
||
|
||
@pytest.fixture(scope="module", params=MODELS) | ||
def get_model_and_set_env(request): | ||
# the tgi_env.py script will take care of setting these | ||
for var in [ | ||
"MAX_BATCH_SIZE", | ||
"MAX_INPUT_LENGTH", | ||
"MAX_TOTAL_TOKEN", | ||
"HF_BATCH_SIZE", | ||
"HF_NUM_CORES", | ||
"HF_SEQUENCE_LENGTH", | ||
"HF_AUTO_CAST_TYPE", | ||
]: | ||
if var in os.environ: | ||
del os.environ[var] | ||
yield request.param | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def tgi_service(launcher, get_model_and_set_env): | ||
with launcher(get_model_and_set_env) as tgi_service: | ||
yield tgi_service | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def tgi_client(tgi_service): | ||
await tgi_service.health(300) | ||
return tgi_service.client | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_model_single_request(tgi_client): | ||
|
||
# Just verify that the generation works, and nothing is raised, with several set of params | ||
|
||
# No params | ||
await tgi_client.generate( | ||
"What is Deep Learning?", | ||
) | ||
|
||
response = await tgi_client.generate( | ||
"How to cook beans ?", | ||
max_new_tokens=17, | ||
decoder_input_details=True, | ||
) | ||
assert response.details.generated_tokens == 17 | ||
|
||
# check error | ||
try: | ||
await tgi_client.generate("What is Deep Learning?", max_new_tokens=170000) | ||
except ValidationError: | ||
pass | ||
else: | ||
raise AssertionError( | ||
"The previous text generation request should have failed, " | ||
"because too many tokens were requested, it succeeded" | ||
) | ||
|
||
# Sampling | ||
await tgi_client.generate( | ||
"What is Deep Learning?", | ||
do_sample=True, | ||
top_k=50, | ||
top_p=0.9, | ||
repetition_penalty=1.2, | ||
max_new_tokens=1000, | ||
seed=42, | ||
decoder_input_details=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/bin/bash | ||
set -e -o pipefail -u | ||
|
||
export ENV_FILEPATH=$(mktemp) | ||
|
||
trap "rm -f ${ENV_FILEPATH}" EXIT | ||
|
||
touch $ENV_FILEPATH | ||
|
||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||
|
||
${SCRIPT_DIR}/tgi_env.py $@ | ||
|
||
source $ENV_FILEPATH | ||
|
||
text-generation-launcher $@ |
Oops, something went wrong.