
Commit

chore: rebase and update doc
keehyuna committed Oct 18, 2024
1 parent fa407bc commit 3ac5da1
Showing 2 changed files with 26 additions and 5 deletions.
docsrc/index.rst (2 additions, 0 deletions)
@@ -66,6 +66,7 @@ Tutorials
* :ref:`converter_overloading`
* :ref:`custom_kernel_plugins`
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`

.. toctree::
:caption: Tutorials
@@ -82,6 +83,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/converter_overloading
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example

Dynamo Frontend
----------------
examples/dynamo/weight_streaming_example.py (24 additions, 5 deletions)
@@ -62,20 +62,26 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
    return time_mean_ms


# Load the LLaMA-2 model
DEVICE = torch.device("cuda:0")
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

# Set input and output sequence lengths
isl = 128
osl = 256

# Create random input tensors
input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
# Convert the model to half precision (FP16)
model = model.half()
with torch.no_grad():
    # Define a dynamic dimension for the sequence length
    seq_len = torch.export.Dim("seq_len", min=1, max=osl)
    # Export the model with dynamic shapes
    # strict=False only enables AOTAutograd tracing and excludes Dynamo.
    llama2_ep = torch.export.export(
        model, tuple(input_tensors), dynamic_shapes=({1: seq_len},), strict=False
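Because dim 1 is marked dynamic, the exported program should accept any sequence length from 1 to osl. A hypothetical sanity check, not part of this diff (ExportedProgram.module() returns a runnable module; the variable name is made up):

longer_input = torch.randint(0, 5, (1, osl), dtype=torch.int64).cuda()
_ = llama2_ep.module()(longer_input)  # runs with the maximum dynamic seq_len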
@@ -89,6 +95,8 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
# the engine with the weight streaming feature. The use_explicit_typing=True option creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_, and only float32 precision is allowed in the enabled_precisions option.
#

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=input_tensors,
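The remaining arguments are collapsed in this hunk. A minimal sketch of how the full call could look, assuming the options described above (enable_weight_streaming and use_explicit_typing are named in this example; the exact option set beyond them is an assumption):

trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=input_tensors,
    enabled_precisions={torch.float32},  # only float32 is allowed with explicit typing
    use_explicit_typing=True,            # build a strongly typed network
    enable_weight_streaming=True,        # build the engine with weight streaming
    device=DEVICE,
)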
@@ -106,11 +114,15 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
# Running with automatic budget size
# ----------------------------------
#
# Once you specify the enable_weight_streaming compile option, an automatic budget size is configured.
# This automatic size may not always provide the optimal solution because the automatically determined
# budget lacks insight into the user's specific memory constraints and usage patterns.

# Weight streaming context to get current weight budget information
weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
# Measure the mean latency of the model with weight streaming
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of current weight budget used
weight_budget_pct = (
    weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
)
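The print that follows is collapsed in this hunk; a sketch of reporting these values, mirroring the format used later in this example (the exact wording is an assumption):

print(
    f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)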
@@ -128,15 +140,19 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
# equal to ctx.total_device_budget will disable weight streaming.
# If multiple trt engines are created, budgets are distributed proportionally.
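A minimal sketch of the disabling rule just described, assuming the same runtime API used below:

# Granting the full budget keeps every weight resident on the GPU,
# which disables weight streaming per the rule above.
with torch_tensorrt.runtime.weight_streaming(trt_model) as ws_ctx:
    ws_ctx.device_budget = ws_ctx.total_device_budget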

# Use a context manager for weight streaming
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget

    # Scenario 1: Automatic weight streaming budget
    # Get the automatically determined weight streaming budget
    requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
    # Set the device budget to the automatically determined value
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with the automatic budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
@@ -146,10 +162,13 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)

    # Scenario 2: Manual 10% weight streaming budget
    # Set the budget to 10% of the total streamable weights
    requested_budget = int(streamable_budget * 0.1)
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with the 10% budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
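The end of this block is collapsed in the diff; a sketch of how the computation presumably finishes, mirroring scenario 1 above (the exact print wording is an assumption):

        * 100
    )
    print(
        f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )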
