test(tgi): add docker tests

dacorvo committed Feb 8, 2024
1 parent e98974a commit 84c6877

Showing 5 changed files with 278 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -97,3 +97,7 @@ tgi_test: tgi_server
	find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
		-exec python -m pip install --force-reinstall {} \;
	python -m pytest -s text-generation-inference/tests

tgi_docker_test: neuronx-tgi
	python -m pip install -r text-generation-inference/integration-tests/requirements.txt
	python -m pytest -s text-generation-inference/integration-tests
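
Note: `tgi_docker_test` depends on the `neuronx-tgi` target, so the Docker image is (re)built before the integration tests run. A minimal, hypothetical stand-in for the last two recipe lines from Python, using the default image tag that conftest.py below falls back to:

# Hypothetical equivalent of the tgi_docker_test recipe; the image tag and the
# tests path are the defaults used elsewhere in this commit.
import os
import subprocess

os.environ.setdefault("DOCKER_IMAGE", "neuronx-tgi:latest")
subprocess.run(
    ["python", "-m", "pytest", "-s", "text-generation-inference/integration-tests"],
    check=True,
)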
155 changes: 155 additions & 0 deletions text-generation-inference/integration-tests/conftest.py
@@ -0,0 +1,155 @@
import asyncio
import contextlib
import os
import random
import shlex
import subprocess
import sys
import time
from tempfile import TemporaryDirectory
from typing import List

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from text_generation import AsyncClient
from text_generation.types import Response


DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "neuronx-tgi:latest")
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")


class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}")

    def _inner_health(self):
        raise NotImplementedError

    async def health(self, timeout: int = 60):
        assert timeout > 0
        # Poll roughly once per second until the server answers a generate
        # request, failing fast if the launcher itself died.
        for _ in range(timeout):
            if not self._inner_health():
                raise RuntimeError("Launcher crashed")

            try:
                await self.client.generate("test")
                return
            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                time.sleep(1)
        raise RuntimeError("Health check failed")


class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
        super().__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        container = self.docker_client.containers.get(self.container_name)
        return container.status in ["running", "created"]


class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
        super().__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        return self.process.poll() is None


@pytest.fixture(scope="module")
def event_loop():
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="module")
def data_volume():
    tmpdir = TemporaryDirectory()
    yield tmpdir.name
    # Clean up the temporary directory using sudo, as it contains root-owned
    # files created by the container.
    subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"))


@pytest.fixture(scope="module")
def launcher(event_loop, data_volume):
    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        trust_remote_code: bool = False,
    ):
        port = random.randint(8000, 10_000)

        args = ["--model-id", model_id, "--env"]

        if trust_remote_code:
            args.append("--trust-remote-code")

        client = docker.from_env()

        container_name = f"tgi-tests-{model_id.split('/')[-1]}"

        # Stop any container left over from a previous run with the same name.
        try:
            container = client.containers.get(container_name)
            container.stop()
            container.wait()
        except NotFound:
            pass

        env = {"LOG_LEVEL": "info,text_generation_router=debug"}

        if HUGGING_FACE_HUB_TOKEN is not None:
            env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN

        # Forward the Neuron export parameters to the container when set.
        for var in ["HF_BATCH_SIZE", "HF_SEQUENCE_LENGTH", "HF_AUTOCAST_TYPE", "HF_NUM_CORES"]:
            if var in os.environ:
                env[var] = os.environ[var]

        volumes = [f"{data_volume}:/data"]

        container = client.containers.run(
            DOCKER_IMAGE,
            command=args,
            name=container_name,
            environment=env,
            auto_remove=False,
            detach=True,
            devices=["/dev/neuron0"],
            volumes=volumes,
            ports={"80/tcp": port},
            shm_size="1G",
        )

        yield ContainerLauncherHandle(client, container.name, port)

        # Tear down, then dump the container logs to stderr for debugging.
        try:
            container.stop()
            container.wait()
        except NotFound:
            pass

        container_output = container.logs().decode("utf-8")
        print(container_output, file=sys.stderr)

        container.remove()

    return docker_launcher


@pytest.fixture(scope="module")
def generate_load():
    async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
        futures = [
            client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True) for _ in range(n)
        ]

        return await asyncio.gather(*futures)

    return generate_load_inner
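
The `launcher` fixture returns a context manager, so a consuming test module only needs to open it for a model id and wait for the health check before using the client. A minimal sketch (the model id and timeout here are illustrative; see test_gpt2.py below for the real usage):

import pytest


@pytest.fixture(scope="module")
def service(launcher):
    # Any model id the neuronx-tgi image can load or export would work here.
    with launcher("gpt2") as service:
        yield service


async def test_service_ready(service):
    # Allow up to 5 minutes for the model to be exported/loaded.
    await service.health(300)
    response = await service.client.generate("Hello", max_new_tokens=1)
    assert response.details.generated_tokens == 1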
2 changes: 2 additions & 0 deletions text-generation-inference/integration-tests/pytest.ini
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
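
`asyncio_mode = auto` makes pytest-asyncio collect every `async def` test without an explicit marker, which is what lets the async fixtures and tests in this suite run on the event loop; the `@pytest.mark.asyncio` markers in test_gpt2.py below are therefore redundant but harmless. A bare coroutine is enough:

# Picked up and run on the event loop automatically under asyncio_mode = auto.
async def test_example():
    assert 1 + 1 == 2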
18 changes: 18 additions & 0 deletions text-generation-inference/integration-tests/requirements.txt
@@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
text-generation >= 0.6.0
pytest >= 7.4.0
pytest-asyncio >= 0.21.1
docker >= 6.1.3
Levenshtein
99 changes: 99 additions & 0 deletions text-generation-inference/integration-tests/test_gpt2.py
@@ -0,0 +1,99 @@
import os

import huggingface_hub
import Levenshtein
import pytest


MODEL_ID = "gpt2"
NEURON_MODEL_ID = "aws-neuron/gpt2-neuronx-bs4-seqlen1024"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024
NUM_CORES = 2


@pytest.fixture(scope="module", params=["hub-neuron", "hub", "local-neuron"])
def model_name_or_path(request, data_volume):
if request.param == "hub":
os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
os.environ["HF_NUM_CORES"] = str(NUM_CORES)
yield MODEL_ID
elif request.param == "hub-neuron":
yield NEURON_MODEL_ID
else:
model_dir = f"gpt2-neuron-{BATCH_SIZE}x{SEQUENCE_LENGTH}x{NUM_CORES}"
local_path = os.path.join(data_volume, model_dir)
huggingface_hub.snapshot_download(NEURON_MODEL_ID, local_dir=local_path)
# Return the path of the model inside the mounted volume
yield os.path.join("/data", model_dir)


@pytest.fixture(scope="module")
def tgi_service(launcher, model_name_or_path):
    with launcher(model_name_or_path) as tgi_service:
        yield tgi_service


@pytest.fixture(scope="module")
async def tgi_client(tgi_service):
    await tgi_service.health(300)
    return tgi_service.client


@pytest.mark.asyncio
async def test_model_single_request(tgi_client):

    # Greedy bounded without input
    response = await tgi_client.generate(
        "What is Deep Learning?",
        max_new_tokens=17,
        decoder_input_details=True,
    )
    assert response.details.generated_tokens == 17
    assert response.generated_text == "\n\nDeep learning is a new field of research that has been around for a while"

    # Greedy bounded with input
    response = await tgi_client.generate(
        "What is Deep Learning?",
        max_new_tokens=17,
        return_full_text=True,
        decoder_input_details=True,
    )
    assert response.details.generated_tokens == 17
    assert (
        response.generated_text
        == "What is Deep Learning?\n\nDeep learning is a new field of research that has been around for a while"
    )

    # Sampling
    response = await tgi_client.generate(
        "What is Deep Learning?",
        do_sample=True,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        max_new_tokens=1000,
        seed=42,
        decoder_input_details=True,
    )
    assert "The purpose of the current post is" in response.generated_text


@pytest.mark.asyncio
async def test_model_multiple_requests(tgi_client, generate_load):
    num_requests = 4
    responses = await generate_load(
        tgi_client,
        "What is Deep Learning?",
        max_new_tokens=17,
        n=num_requests,
    )

    assert len(responses) == num_requests
    expected = "\n\nDeep learning is a new field of research that has been around for a while"
    for r in responses:
        assert r.details.generated_tokens == 17
        # Compute the similarity with the expectation using the Levenshtein distance.
        # We should not have more than two substitutions or additions.
        assert Levenshtein.distance(r.generated_text, expected) < 3
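
For reference, `Levenshtein.distance` counts single-character insertions, deletions, and substitutions, so the `< 3` bound above tolerates at most two such edits between each concurrent response and the expected greedy continuation:

import Levenshtein

# Two substitutions (k->s, e->i) plus one insertion (g) give a distance of 3.
assert Levenshtein.distance("kitten", "sitting") == 3
assert Levenshtein.distance("kitten", "kitten") == 0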
