diff --git a/text-generation-inference/integration-tests/conftest.py b/text-generation-inference/integration-tests/conftest.py index dd8616a3f..0017ddb57 100644 --- a/text-generation-inference/integration-tests/conftest.py +++ b/text-generation-inference/integration-tests/conftest.py @@ -1,10 +1,13 @@ import asyncio import contextlib +import logging import os import random import shlex +import string import subprocess import sys +import tempfile import time from tempfile import TemporaryDirectory from typing import List @@ -17,6 +20,7 @@ from text_generation.types import Response +LOG = logging.getLogger(__file__) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "neuronx-tgi:latest") HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") @@ -114,8 +118,33 @@ def docker_launcher( volumes = [f"{data_volume}:/data"] - container = client.containers.run( + # Workaround to bypass docker dind issues preventing to share a volume from the container running tests + # to another + docker_content = f""" + FROM {DOCKER_IMAGE} + COPY {data_volume}/. /data + """ + + docker_tag = "awesome-workaround:{}".format( + "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5)) + ) + + LOG.info( + "Building image on the flight derivated from %s, embedding %s content, tagged with %s", DOCKER_IMAGE, + data_volume, + docker_tag, + ) + + with tempfile.NamedTemporaryFile() as f: + f.write(docker_content.encode("utf-8")) + f.flush() + image, logs = client.images.build(path=".", dockerfile=f.name, tag=docker_tag) + + LOG.debug("Build logs %s", logs) + + container = client.containers.run( + docker_tag, command=args, name=container_name, environment=env, @@ -130,15 +159,24 @@ def docker_launcher( yield ContainerLauncherHandle(client, container.name, port) try: - container.stop() - container.wait() - except NotFound: - pass - - container_output = container.logs().decode("utf-8") - print(container_output, file=sys.stderr) - - container.remove() + try: + container.stop() + container.wait() + except NotFound: + pass + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + container.remove() + finally: + # Cleanup the build image + try: + image.remove(force=True) + except NotFound: + pass + except Exception as e: + LOG.error("Error while removing image %s, skiping", image.id) + LOG.exception(e) return docker_launcher