diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index cac615d6ae..a792ce4329 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -36,6 +36,7 @@ ) from torch_tensorrt.dynamo.utils import ( get_flat_args_with_check, + get_output_meta_val, parse_graph_io, prepare_inputs, set_log_level, @@ -295,7 +296,6 @@ def compile( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( get_decompositions(enable_experimental_decompositions) @@ -451,40 +451,18 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: ) continue - # set the submodule meta val back to the parent trt_module_node - outputs = [node for node in submodule.graph.nodes if node.op == "output"] - outputs = outputs[0].args - outputs_meta_val = [] - for ele in outputs: - # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node - if isinstance(ele, torch.fx.node.Node): - if "val" not in ele.meta: - raise ValueError( - f"node.name={ele.name}: meta['val'] does not exist, expect submodule output node has meta['val'] info" - ) - outputs_meta_val.append(ele.meta["val"]) - elif isinstance(ele, tuple): - for node in ele: - if isinstance(node, torch.fx.node.Node): - if "val" not in node.meta: - raise ValueError( - f"{node.name=}: meta['val'] does not exist, expect submodule output node has meta['val'] info" - ) - outputs_meta_val.append(node.meta["val"]) - else: - raise ValueError( - f"expect torch.fx.node.Node type, got not expected types: {type(node)=}" - ) - else: - raise ValueError( - f"expect torch.fx.node.Node or tuple of torch.fx.node.Node type, got not expected types: {type(ele)=}" - ) - if name not in submodule_node_dict: raise ValueError( f"node_name: {name} does not exist in the submodule node dictionary" ) - submodule_node_dict[name].meta["val"] = outputs_meta_val + + # set the submodule meta val back to the parent trt_module_node + if "val" not in submodule_node_dict[name].meta: + outputs = [node for node in submodule.graph.nodes if node.op == "output"] + outputs = outputs[0].args + outputs_meta_val = get_output_meta_val(outputs) + assert len(outputs_meta_val) > 0 + submodule_node_dict[name].meta["val"] = outputs_meta_val subgraph_data = PerSubgraphData() subgraph_data.subgraph_name = name diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 33bbfe8bb8..ed2e990823 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -5,7 +5,6 @@ import tensorrt as trt import torch -from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,35 +17,15 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs +from torch_tensorrt.dynamo.utils import ( + get_model_device, + get_output_dtypes, + get_torch_inputs, +) logger = logging.getLogger(__name__) -def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: - output_dtypes = [] - if isinstance(output, torch.fx.node.Node): - if "val" in output.meta: - output_meta = output.meta["val"] - if isinstance(output_meta, (FakeTensor, torch.Tensor)): - if truncate_doulbe and output_meta.dtype == torch.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_meta.dtype)) - else: - raise ValueError( - f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" - ) - elif isinstance(output, tuple): - for ele in output: - output_dtypes.extend(get_output_dtypes(ele)) - else: - raise ValueError( - f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" - ) - return output_dtypes - - def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index 9bd7ed8422..54e63a496f 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) +from torch_tensorrt.dynamo.utils import get_output_meta_val, set_output_meta_val logger = logging.getLogger(__name__) @@ -14,12 +15,23 @@ def lower_linear( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Replace aten.linear with an equivalent implementation which can be easily converted to TRT""" + + outputs = [node for node in gm.graph.nodes if node.op == "output"] + outputs = outputs[0].args + outputs_meta_val = get_output_meta_val(outputs) + orig, replacement = linear_replacement() + replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement) - if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + if len(replaced_nodes) > 0: gm = clean_up_graph_after_modifications(gm) logger.debug(f"Graph after lowering linear:\n{gm.graph}") + outputs = [node for node in gm.graph.nodes if node.op == "output"] + outputs = outputs[0].args + output_num = len(outputs_meta_val) + assert output_num > 0 + set_output_meta_val(outputs, outputs_meta_val) return gm diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index a85494239e..b8e1489fc4 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -661,3 +661,57 @@ def get_flat_args_with_check( flat_args_with_path, received_spec = pytree.tree_flatten_with_path((args, kwargs)) flat_args = tuple(x[1] for x in flat_args_with_path) return flat_args, received_spec + + +def get_output_meta_val(output: Any) -> List[Any]: + output_meta_val = [] + if isinstance(output, torch.fx.node.Node): + if "val" in output.meta: + output_meta_val.append(output.meta["val"]) + elif isinstance(output, tuple): + for node in output: + output_meta_val.extend(get_output_meta_val(node)) + else: + raise ValueError( + f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + ) + return output_meta_val + + +def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None: + if isinstance(output, torch.fx.node.Node): + assert len(outputs_meta_val) > 0 + if "val" not in output.meta: + output.meta["val"] = outputs_meta_val[0] + outputs_meta_val.pop(0) + elif isinstance(output, tuple): + for node in output: + set_output_meta_val(node, outputs_meta_val) + else: + raise ValueError( + f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected types: {type(output)=}" + ) + + +def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]: + output_dtypes = [] + if isinstance(output, torch.fx.node.Node): + if "val" in output.meta: + output_meta = output.meta["val"] + if isinstance(output_meta, (FakeTensor, torch.Tensor)): + if truncate_doulbe and output_meta.dtype == torch.float64: + output_dtypes.append(dtype.float32) + else: + output_dtypes.append(dtype._from(output_meta.dtype)) + else: + raise ValueError( + f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] exists for each output node" + ) + elif isinstance(output, tuple): + for ele in output: + output_dtypes.extend(get_output_dtypes(ele)) + else: + raise ValueError( + f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple of torch.fx.node.Node" + ) + return output_dtypes diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index ba6cb0c776..b6f986711a 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -29,7 +29,6 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -62,7 +61,6 @@ def test_mobilenet_v2(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 10, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -95,7 +93,6 @@ def test_efficientnet_b0(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 10, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -137,7 +134,6 @@ def test_bert_base_uncased(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 15, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, } @@ -173,7 +169,6 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "ir": "torch_compile", "cache_built_engines": False, "reuse_cached_engines": False, }