
Commit

update
zewenli98 committed Nov 6, 2024
1 parent 1ae33f4 commit 90bf679
Showing 5 changed files with 20 additions and 22 deletions.
15 changes: 6 additions & 9 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -466,16 +466,13 @@ def refit_module_weights(
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialized_engine = engine.serialize_with_config(serialization_config)
- engine = runtime.deserialize_cuda_engine(serialized_engine)

- if isinstance(compiled_submodule, PythonTorchTensorRTModule):
-     compiled_submodule.engine = engine
-
- if isinstance(compiled_submodule, TorchTensorRTModule):
-     new_engine_info = list(engine_info)
-     new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
-     refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
-     compiled_submodule.engine = refitted_engine
+ if isinstance(
+     compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
+ ):
+     compiled_submodule.engine = None  # Clear the engine for TorchTensorRTModule, otherwise it won't be updated
+     compiled_submodule.serialized_engine = bytes(serialized_engine)
+     compiled_submodule.setup_engine()

elif inline_module:
new_engine_info = list(engine_info)
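For context, refit_module_weights is the user-facing entry point whose internals change above. A minimal sketch of a call (not part of this commit; trt_gm, new_exp_program, and example_inputs are placeholders, and the keyword names follow the torch_tensorrt.dynamo API as best recalled, so treat them as assumptions):

from torch_tensorrt.dynamo import refit_module_weights

# trt_gm: a module previously compiled with torch_tensorrt.dynamo.compile(..., immutable_weights=False)
# new_exp_program: torch.export.export(...) of the same architecture with updated weights
refitted_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=new_exp_program,
    arg_inputs=example_inputs,  # assumption: only needed when output verification is enabled
)
out = refitted_gm(*example_inputs)

With the change above, both PythonTorchTensorRTModule and TorchTensorRTModule are updated the same way: the refitted bytes are stored on serialized_engine and setup_engine() rebuilds the runtime engine from them.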
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -139,6 +139,7 @@ class CompilationSettings:
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
)
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -111,6 +111,10 @@ def _pretraced_backend(
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
+ if settings.strip_engine_weights:
+     logger.warning(
+         "strip_engine_weights arg is not supported for torch.compile()"
+     )
trt_compiled = compile_module(
gm,
torchtrt_inputs,
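For reference, the torch.compile path that now emits this warning (a sketch, not part of the commit; it mirrors the usage removed from the test file below, with MyModel standing in for any eligible module):

import torch
import torch_tensorrt  # noqa: F401  # registers the "tensorrt" backend for torch.compile

model = MyModel().eval().cuda()  # hypothetical model
example_inputs = (torch.randn(1, 3, 224, 224).cuda(),)

compiled = torch.compile(
    model,
    backend="tensorrt",
    options={"strip_engine_weights": True},  # triggers the new warning above
)
out = compiled(*example_inputs)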
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -547,7 +547,7 @@ def _save_weight_mapping(self) -> None:
torch.cuda.empty_cache()

def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
- # TODO: Waiting for TRT's feature to cache the weight-stripped engine
+ # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
@@ -624,7 +624,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
)
serialized_engine = engine.serialize()

- # TODO: Waiting for TRT's feature to load the weight-stripped engine
+ # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
# serialization_config = engine.create_serialization_config()
# serialization_config.clear_flag(
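For orientation, the serialization-flag flow that the commented-out code and the TODOs refer to (a sketch, not part of the commit; serialized_engine is assumed to hold full engine bytes, and TRT_LOGGER/runtime follow the names used in the snippets above):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

# Cache path (_insert_engine_to_cache): serialize without weights for a smaller artifact.
strip_config = engine.create_serialization_config()
strip_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
weight_stripped_bytes = engine.serialize_with_config(strip_config)

# Load path (_pull_cached_engine): clear the flag so the serialized form carries weights
# again; the weights themselves have to be refitted into the engine first.
full_config = engine.create_serialization_config()
full_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
full_bytes = engine.serialize_with_config(full_config)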
18 changes: 7 additions & 11 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
@@ -94,13 +94,13 @@ def test_three_ways_to_compile_weight_stripped_engine(self):
)
gm2_output = gm2(*example_inputs)

- # 3. Compile with torch.compile using tensorrt backend
- gm3 = torch.compile(
-     pyt_model,
-     backend="tensorrt",
-     options=settings,
- )
- gm3_output = gm3(*example_inputs)
+ # 3. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True
+ # gm3 = torch.compile(
+ #     pyt_model,
+ #     backend="tensorrt",
+ #     options=settings,
+ # )
+ # gm3_output = gm3(*example_inputs)

assertions.assertEqual(
gm1_output.sum(), 0, msg="gm1_output should be all zeros"
@@ -110,10 +110,6 @@
gm2_output.sum(), 0, msg="gm2_output should be all zeros"
)

- assertions.assertEqual(
-     gm3_output.sum(), 0, msg="gm3_output should be all zeros"
- )
-
def test_weight_stripped_engine_sizes(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
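Stepping back, the workflow these tests exercise: a weight-stripped engine still runs, but with zeroed weights (hence the all-zeros assertions above), and only after refitting does it produce real outputs. A compact sketch under that assumption (not copied from the test body; the compile/refit keyword names follow the dynamo API and should be treated as assumptions):

import torch
import torch_tensorrt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
exp_program = torch.export.export(pyt_model, example_inputs)

stripped_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=example_inputs,
    strip_engine_weights=True,   # weight-stripped engine
    immutable_weights=False,     # assumption: keep the engine refittable
)
assert stripped_gm(*example_inputs).sum() == 0  # structure only, zeroed weights

refitted_gm = refit_module_weights(stripped_gm, exp_program, arg_inputs=example_inputs)
out = refitted_gm(*example_inputs)  # now comparable to pyt_model(*example_inputs)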
