Skip to content

Commit

Permalink
fix windows cross compile, TODO: whether windows support stripping en…
Browse files Browse the repository at this point in the history
…gine?
  • Loading branch information
zewenli98 committed Nov 8, 2024
1 parent d57b885 commit 23d68d5
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def cross_compile_for_windows(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
Expand Down Expand Up @@ -93,6 +92,7 @@ def cross_compile_for_windows(
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -132,7 +132,6 @@ def cross_compile_for_windows(
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
Expand Down Expand Up @@ -164,6 +163,7 @@ def cross_compile_for_windows(
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
**kwargs: Any,
Returns:
Expand Down Expand Up @@ -193,14 +193,17 @@ def cross_compile_for_windows(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refittable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

Expand Down Expand Up @@ -275,7 +278,6 @@ def cross_compile_for_windows(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
Expand All @@ -286,6 +288,7 @@ def cross_compile_for_windows(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
}
Expand Down

0 comments on commit 23d68d5

Please sign in to comment.