Bump TGI version and fix bugs (#666)
Co-authored-by: Michael Benayoun <[email protected]>
dacorvo and michaelbenayoun authored Jul 29, 2024
1 parent c63980b commit 0332ba6
Showing 6 changed files with 42 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -40,7 +40,7 @@ PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \
$(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
python -m build

-TGI_VERSION ?= 2.0.2
+TGI_VERSION ?= 2.1.1

neuronx-tgi: $(PACKAGE_DIST)
docker build --rm -f text-generation-inference/Dockerfile \
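Note: the Makefile change above only moves the TGI source pin from 2.0.2 to 2.1.1; the docker build rule that consumes it is truncated in this view. As a rough sketch (the TGI_VERSION build-arg name and the image tag are assumptions, not taken from this diff), forwarding such a pin into the image build could look like:

    # Hypothetical sketch of forwarding the pinned TGI version into the image build;
    # the build-arg name and tag below are assumptions, not from the repository.
    import subprocess

    TGI_VERSION = "2.1.1"

    subprocess.run(
        [
            "docker", "build", "--rm",
            "-f", "text-generation-inference/Dockerfile",
            "--build-arg", f"TGI_VERSION={TGI_VERSION}",  # assumed wiring of the Make variable
            "-t", f"neuronx-tgi:tgi-{TGI_VERSION}",       # illustrative tag
            ".",
        ],
        check=True,
    )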
6 changes: 3 additions & 3 deletions text-generation-inference/Dockerfile
@@ -8,7 +8,7 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
-FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -20,8 +20,6 @@ COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/launcher launcher
-# Remove the next line when bumping rust version
-RUN cargo update ravif --precise 0.11.6
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
@@ -41,6 +39,8 @@ COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/launcher launcher
+# Remove this line once TGI has fixed the conflict
+RUN cargo update ureq --precise 2.9.7
RUN cargo build --release --workspace --exclude benchmark

# Python base image
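Note: the swapped `RUN cargo update <crate> --precise <version>` lines are temporary workarounds: each pins a single transitive crate to a version that still builds with the pinned toolchain, without regenerating the rest of the lockfile. For readers more familiar with Python tooling, the closest analogue (purely an analogy, not something this Dockerfile does) is pinning one transitive package with a constraints file:

    # Python analogue of `cargo update <crate> --precise <version>`: pin a single
    # transitive dependency while leaving the rest of the resolution untouched.
    # The package and version below are hypothetical.
    import subprocess
    import sys

    with open("constraints.txt", "w") as f:
        f.write("urllib3==2.2.2\n")  # hypothetical transitive pin

    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-c", "constraints.txt", "requests"],
        check=True,
    )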
@@ -17,6 +17,9 @@ def serve(
uds_path: str = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
+otlp_endpoint: Optional[str] = None,
+otlp_service_name: str = "text-generation-inference.server",
+max_input_tokens: Optional[int] = None,
):
"""This is the main entry-point for the server CLI.
@@ -36,6 +39,12 @@ def serve(
The server logger level. Defaults to *INFO*.
json_output (`bool`):
Use JSON format for log serialization.
+otlp_endpoint (`Optional[str]`, defaults to `None`):
+The Open Telemetry endpoint to use.
+otlp_service_name (`str`, defaults to `"text-generation-inference.server"`):
+The name to use when pushing data to the Open Telemetry endpoint.
+max_input_tokens (`Optional[int]`, defaults to `None`):
+The maximum number of input tokens each request should contain.
"""
if sharded:
raise ValueError("Sharding is not supported.")
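Note: the three new serve() parameters are plumbed through the server CLI. A minimal, self-contained stand-in (not the real entry point, which lives in the TGI Neuron server package) showing how options with these types and defaults surface as command-line flags under a typer-style CLI, which the signature above suggests:

    # Reduced stand-in for the serve() command above, assuming a typer-based CLI;
    # the real implementation and its other options are omitted.
    from typing import Optional

    import typer

    app = typer.Typer()


    @app.command()
    def serve(
        model_id: str,
        otlp_endpoint: Optional[str] = None,
        otlp_service_name: str = "text-generation-inference.server",
        max_input_tokens: Optional[int] = None,
    ):
        if otlp_endpoint is None:
            typer.echo("OTLP tracing disabled (no endpoint configured)")
        else:
            typer.echo(f"exporting traces to {otlp_endpoint} as {otlp_service_name}")
        if max_input_tokens is not None:
            typer.echo(f"requests limited to {max_input_tokens} input tokens")


    if __name__ == "__main__":
        app()  # e.g. `python sketch.py my-model --otlp-endpoint http://collector:4317`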
@@ -472,8 +472,18 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
# just carry on with decoding. We adopt the id of the first
# batch in the list as our next batch id.
next_batch_id = batches[0].id
+request_ids = []
+for batch in batches:
+request_ids += batch.request_ids
+cleared_request_ids = []
+for slot in self.slots:
+if slot.state == slot.State.READY and slot.request_id not in request_ids:
+cleared_request_ids.append(slot.request_id)
+slot.clear()
+if len(cleared_request_ids) > 0:
+logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
-if len(active_slots) == 0:
+if len(active_slots) < len(request_ids):
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
if self.model.continuous_batching:
decode_slots = active_slots
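Note: the added block collects the request ids of all batches being decoded, frees any READY slot whose request is no longer among them, and then checks that every remaining request actually has a prefilled slot. A standalone sketch of that bookkeeping, with a toy Slot class standing in for the real one (names are assumptions):

    # Toy reproduction of the slot-clearing logic above; the real Slot/Generator
    # classes live in the neuron TGI server and carry more state than this.
    from dataclasses import dataclass
    from enum import Enum
    from typing import List


    class State(Enum):
        EMPTY = 0
        READY = 1


    @dataclass
    class Slot:
        request_id: int
        state: State = State.READY

        def clear(self):
            self.state = State.EMPTY


    def prune_stale_slots(slots: List[Slot], batch_request_ids: List[int]) -> List[int]:
        """Clear READY slots whose request no longer appears in the batches being decoded."""
        cleared = []
        for slot in slots:
            if slot.state == State.READY and slot.request_id not in batch_request_ids:
                cleared.append(slot.request_id)
                slot.clear()
        return cleared


    slots = [Slot(1), Slot(2), Slot(3)]
    print(prune_stale_slots(slots, batch_request_ids=[1, 3]))  # [2]: request 2 was dropped upstream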
@@ -17,8 +17,8 @@ async def tgi_service(request, launcher, neuron_model_config):
# the tgi_env.py script will take care of setting these
for var in [
"MAX_BATCH_SIZE",
"MAX_INPUT_LENGTH",
"MAX_TOTAL_TOKEN",
"MAX_INPUT_TOKENS",
"MAX_TOTAL_TOKENS",
"HF_NUM_CORES",
"HF_AUTO_CAST_TYPE",
]:
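Note: as the comment in the fixture says, tgi_env.py is expected to set these variables, so the test keeps them out of the environment; the change simply follows the MAX_INPUT_LENGTH to MAX_INPUT_TOKENS rename. A generic sketch of the "run the test with these variables unset" pattern using pytest's monkeypatch (the repository's actual fixture may differ):

    # Generic sketch of keeping router/server variables out of the test environment;
    # monkeypatch restores whatever was set once the test finishes.
    import pytest

    ROUTER_AND_SERVER_VARS = [
        "MAX_BATCH_SIZE",
        "MAX_INPUT_TOKENS",
        "MAX_TOTAL_TOKENS",
        "HF_NUM_CORES",
        "HF_AUTO_CAST_TYPE",
    ]


    @pytest.fixture
    def clean_tgi_env(monkeypatch):
        for var in ROUTER_AND_SERVER_VARS:
            monkeypatch.delenv(var, raising=False)
        yield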
30 changes: 16 additions & 14 deletions text-generation-inference/tgi_env.py
@@ -16,7 +16,7 @@

logger = logging.getLogger(__name__)

-tgi_router_env_vars = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS", "MAX_INPUT_LENGTH"]
+tgi_router_env_vars = ["MAX_BATCH_SIZE", "MAX_TOTAL_TOKENS", "MAX_INPUT_TOKENS"]
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]

env_config_peering = [
@@ -38,7 +38,9 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
if not argv:
argv = sys.argv
# All these are params passed to tgi and intercepted here
parser.add_argument("--max-input-length", type=int, default=os.getenv("MAX_INPUT_LENGTH", 0))
parser.add_argument(
"--max-input-tokens", type=int, default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
)
parser.add_argument("--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0))
parser.add_argument("--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0))
parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID"))
@@ -57,8 +59,8 @@ def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
if args.max_total_tokens > 0:
os.environ["MAX_TOTAL_TOKENS"] = str(args.max_total_tokens)

-if args.max_input_length > 0:
-os.environ["MAX_INPUT_LENGTH"] = str(args.max_input_length)
+if args.max_input_tokens > 0:
+os.environ["MAX_INPUT_TOKENS"] = str(args.max_input_tokens)

if args.max_batch_size > 0:
os.environ["MAX_BATCH_SIZE"] = str(args.max_batch_size)
@@ -73,12 +75,12 @@ def neuron_config_to_env(neuron_config):
with open(os.environ["ENV_FILEPATH"], "w") as f:
for env_var, config_key in env_config_peering:
f.write("export {}={}\n".format(env_var, neuron_config[config_key]))
-max_input_length = os.getenv("MAX_INPUT_LENGTH")
-if not max_input_length:
-max_input_length = int(neuron_config["sequence_length"]) // 2
-if max_input_length == 0:
+max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
+if not max_input_tokens:
+max_input_tokens = int(neuron_config["sequence_length"]) // 2
+if max_input_tokens == 0:
raise Exception("Model sequence length should be greater than 1")
-f.write("export MAX_INPUT_LENGTH={}\n".format(max_input_length))
+f.write("export MAX_INPUT_TOKENS={}\n".format(max_input_tokens))


def sort_neuron_configs(dictionary):
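Note: when neither variable is exported, neuron_config_to_env falls back to half of the compiled sequence length. A toy reproduction of that derivation, with an illustrative config value:

    # Toy reproduction of the derived default written to the env file above.
    import os

    os.environ.pop("MAX_INPUT_TOKENS", None)   # force the derived default
    neuron_config = {"sequence_length": 4096}  # illustrative exported value

    max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
    if not max_input_tokens:
        max_input_tokens = int(neuron_config["sequence_length"]) // 2
        if max_input_tokens == 0:
            raise Exception("Model sequence length should be greater than 1")
    print(f"export MAX_INPUT_TOKENS={max_input_tokens}")  # export MAX_INPUT_TOKENS=2048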
@@ -149,13 +151,13 @@ def check_env_and_neuron_config_compatibility(neuron_config: Dict[str, Any], che
)
return False

if os.getenv("MAX_INPUT_LENGTH"):
max_input_length = int(os.environ["MAX_INPUT_LENGTH"])
max_input_tokens = int(os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)))
if max_input_tokens > 0:
sequence_length = neuron_config["sequence_length"]
-if max_input_length >= sequence_length:
+if max_input_tokens >= sequence_length:
logger.debug(
"Specified max input length is not compatible with config sequence length " "( %s >= %s)",
max_input_length,
"Specified max input tokens is not compatible with config sequence length " "( %s >= %s)",
max_input_tokens,
sequence_length,
)
return False
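Note: the compatibility check now compares the (possibly legacy-sourced) max_input_tokens against the compiled sequence_length and rejects cached configs where the input budget is not strictly smaller, presumably so at least some tokens can still be generated. The predicate boils down to:

    # Standalone version of the predicate: a cached neuron config is only usable if the
    # requested input budget is strictly below its compiled sequence length.
    def input_budget_fits(max_input_tokens: int, sequence_length: int) -> bool:
        return max_input_tokens < sequence_length

    print(input_budget_fits(2048, 4096))  # True
    print(input_budget_fits(4096, 4096))  # False: rejected by the check above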
