diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index dda014890d..5fcccb5c77 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -4,8 +4,8 @@ import torch from torch._decomp import register_decomposition -from torch._export.utils import _decomp_table_to_post_autograd_aten from torch._ops import OpOverload +from torch.export import default_decompositions from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.dynamo.utils import to_torch_device @@ -412,7 +412,8 @@ def get_decompositions( return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS} else: # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080 - decomp_table = _decomp_table_to_post_autograd_aten() + # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/140085 + decomp_table = default_decompositions() DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: decomp_table[decomp] for decomp in decomp_table