diff --git a/docsrc/index.rst b/docsrc/index.rst
index b4ede94404..5d88c8ecae 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -66,6 +66,7 @@ Tutorials
 * :ref:`converter_overloading`
 * :ref:`custom_kernel_plugins`
 * :ref:`mutable_torchtrt_module_example`
+* :ref:`weight_streaming_example`
 
 .. toctree::
    :caption: Tutorials
@@ -82,6 +83,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/converter_overloading
    tutorials/_rendered_examples/dynamo/custom_kernel_plugins
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
+   tutorials/_rendered_examples/dynamo/weight_streaming_example
 
 Dynamo Frontend
 ----------------
diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py
index f88e16b2ea..6ab1bb182b 100644
--- a/examples/dynamo/weight_streaming_example.py
+++ b/examples/dynamo/weight_streaming_example.py
@@ -62,6 +62,7 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
     return time_mean_ms
 
 
+# Load the LLaMA-2 model
 DEVICE = torch.device("cuda:0")
 llama_path = "meta-llama/Llama-2-7b-chat-hf"
 with torch.no_grad():
@@ -69,13 +70,18 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
         llama_path, use_cache=False, attn_implementation="eager"
     ).eval()
 
+# Set input and output sequence lengths
 isl = 128
 osl = 256
 
+# Create random input tensors
 input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
+# Convert the model to half precision (FP16)
 model = model.half()
 with torch.no_grad():
+    # Define a dynamic dimension for the sequence length
     seq_len = torch.export.Dim("seq_len", min=1, max=osl)
+    # Export the model with dynamic shapes
     # strict=False only enables aotautograd tracing and excludes dynamo.
     llama2_ep = torch.export.export(
         model, tuple(input_tensors), dynamic_shapes=({1: seq_len},), strict=False
@@ -89,6 +95,8 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
 # the engine with weight streaming feature. use_explicit_typing=True option creates a
 # `strongly typed network `_ and only float32 precision is allowed in enabled_precisions option
 #
+
+# Create a TensorRT-compiled model with weight streaming enabled
 trt_model = torch_tensorrt.dynamo.compile(
     llama2_ep,
     inputs=input_tensors,
@@ -106,11 +114,15 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
 # Running with automatic budget size
 # ----------------------------------
 #
-# Once you specify the enable_weight_streaming option, automatic budget size is configured.
+# Once you specify the enable_weight_streaming compile option, an automatic budget size is configured.
 # This automatic size may not always provide the optimal solution because the automatically determined
 # budget lacks insight into the user's specific memory constraints and usage patterns
+
+# Create a weight streaming context to query the current weight budget
 weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
+# Measure the mean latency of the model with weight streaming
 mean_latency = time_generate(trt_model, input_tensors, osl, 1)
+# Calculate the percentage of the total weight budget in use
 weight_budget_pct = (
     weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
 )
@@ -128,15 +140,19 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
 # equal to ctx.total_device_budget will disable weight streaming.
 # If multiple trt engines are created, budgets are distributed proportionally
 
+# Use a context manager for weight streaming
 with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
-    # The size of the streamable weights in the engine
+    # Get the total size of the streamable weights in the engine
     streamable_budget = weight_streaming_ctx.total_device_budget
 
-    # get automatic weight streaming budget size by using get_automatic_weight_streaming_budget
+    # Scenario 1: Automatic weight streaming budget
+    # Get the automatically determined weight streaming budget
     requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
-    # Set and get the current weight streaming budget for inference
+    # Set the device budget to the automatically determined value
     weight_streaming_ctx.device_budget = requested_budget
+    # Measure the mean latency with the automatic budget
     mean_latency = time_generate(trt_model, input_tensors, osl, 1)
+    # Calculate the percentage of the weight budget in use
     weight_budget_pct = (
         weight_streaming_ctx.device_budget
         / weight_streaming_ctx.total_device_budget
@@ -146,10 +162,13 @@ def time_generate(model, inputs, output_seq_length, iterations=10):
         f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
     )
 
-    # Set 10% of weight streaming budget
+    # Scenario 2: Manual 10% weight streaming budget
+    # Set the budget to 10% of the total streamable weights
     requested_budget = int(streamable_budget * 0.1)
     weight_streaming_ctx.device_budget = requested_budget
+    # Measure the mean latency with the 10% budget
     mean_latency = time_generate(trt_model, input_tensors, osl, 1)
+    # Calculate the percentage of the weight budget in use
     weight_budget_pct = (
         weight_streaming_ctx.device_budget
         / weight_streaming_ctx.total_device_budget
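Note for readers trying the new tutorial without the full LLaMA-2 setup: the compile-time requirements described above (a strongly typed network via use_explicit_typing=True, float32 as the only entry in enabled_precisions, and enable_weight_streaming=True) can be exercised on a toy model. The sketch below is illustrative and not part of the diff; TinyMLP and its input shape are hypothetical stand-ins, while the option names come from the example itself.

import torch
import torch_tensorrt


class TinyMLP(torch.nn.Module):
    # Hypothetical stand-in for the LLaMA-2 model in the tutorial
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(64, 256), torch.nn.ReLU(), torch.nn.Linear(256, 64)
        )

    def forward(self, x):
        return self.net(x)


model = TinyMLP().half().eval().cuda()
inputs = [torch.randn(8, 64, dtype=torch.half, device="cuda")]

# Export first, then compile. Weight streaming requires a strongly typed
# network (use_explicit_typing=True), and enabled_precisions may then
# contain only torch.float32.
ep = torch.export.export(model, tuple(inputs))
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=inputs,
    enabled_precisions={torch.float32},
    use_explicit_typing=True,
    enable_weight_streaming=True,
)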
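Because any budget between 0 and total_device_budget is valid, the runtime API shown in the diff also lends itself to a simple sweep over budget fractions to find a workable memory/latency trade-off. This sketch assumes trt_model, input_tensors, osl, and time_generate from the example above; the sample fractions are arbitrary.

# Sweep a few weight budget fractions and report latency, using only the
# context-manager API demonstrated in the diff.
with torch_tensorrt.runtime.weight_streaming(trt_model) as ws_ctx:
    # Total bytes of streamable weights in the engine(s)
    total = ws_ctx.total_device_budget
    for fraction in (0.05, 0.25, 0.5, 1.0):  # arbitrary points; 1.0 disables streaming
        ws_ctx.device_budget = int(total * fraction)
        latency = time_generate(trt_model, input_tensors, osl, 1)
        print(
            f"budget={fraction:.0%} ({ws_ctx.device_budget} / {total} bytes): "
            f"mean latency = {latency} ms"
        )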