Commit

Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace
lanluo-nvidia committed Oct 23, 2024
2 parents 891e963 + cff64a4 commit 814262f
Showing 5 changed files with 81 additions and 63 deletions.
40 changes: 9 additions & 31 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     get_flat_args_with_check,
+    get_output_meta_val,
     parse_graph_io,
     prepare_inputs,
     set_log_level,
@@ -295,7 +296,6 @@ def compile(
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-
     exported_program = pre_export_lowering(exported_program, settings)
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
@@ -451,40 +451,18 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 )
                 continue
 
-            # set the submodule meta val back to the parent trt_module_node
-            outputs = [node for node in submodule.graph.nodes if node.op == "output"]
-            outputs = outputs[0].args
-            outputs_meta_val = []
-            for ele in outputs:
-                # it can be a torch.fx.node.Node or a tuple of torch.fx.node.Node
-                if isinstance(ele, torch.fx.node.Node):
-                    if "val" not in ele.meta:
-                        raise ValueError(
-                            f"node.name={ele.name}: meta['val'] does not exist, expect submodule output node to have meta['val'] info"
-                        )
-                    outputs_meta_val.append(ele.meta["val"])
-                elif isinstance(ele, tuple):
-                    for node in ele:
-                        if isinstance(node, torch.fx.node.Node):
-                            if "val" not in node.meta:
-                                raise ValueError(
-                                    f"{node.name=}: meta['val'] does not exist, expect submodule output node to have meta['val'] info"
-                                )
-                            outputs_meta_val.append(node.meta["val"])
-                        else:
-                            raise ValueError(
-                                f"expect torch.fx.node.Node type, got unexpected type: {type(node)=}"
-                            )
-                else:
-                    raise ValueError(
-                        f"expect torch.fx.node.Node or tuple of torch.fx.node.Node type, got unexpected type: {type(ele)=}"
-                    )
-
             if name not in submodule_node_dict:
                 raise ValueError(
                     f"node_name: {name} does not exist in the submodule node dictionary"
                 )
-            submodule_node_dict[name].meta["val"] = outputs_meta_val
+
+            # set the submodule meta val back to the parent trt_module_node
+            if "val" not in submodule_node_dict[name].meta:
+                outputs = [node for node in submodule.graph.nodes if node.op == "output"]
+                outputs = outputs[0].args
+                outputs_meta_val = get_output_meta_val(outputs)
+                assert len(outputs_meta_val) > 0
+                submodule_node_dict[name].meta["val"] = outputs_meta_val
 
             subgraph_data = PerSubgraphData()
             subgraph_data.subgraph_name = name
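
For context, a minimal self-contained sketch of the pattern this hunk implements, using plain torch.fx with no TensorRT. The module is a hypothetical stand-in for a partitioned TRT submodule, and meta["val"] is filled in by hand here, whereas the compiler obtains it from fake-tensor tracing:

import torch
import torch.fx


class Sub(torch.nn.Module):  # hypothetical stand-in for a "_run_on_acc" submodule
    def forward(self, x):
        return x + 1


submodule = torch.fx.symbolic_trace(Sub())

# Simulate the meta["val"] that fake-tensor tracing attaches to each node.
for node in submodule.graph.nodes:
    if node.op != "output":
        node.meta["val"] = torch.empty(2, 3)

# Harvest meta["val"] from the output node's args (the job that
# get_output_meta_val now does), ready to copy onto the parent node.
output_node = [n for n in submodule.graph.nodes if n.op == "output"][0]
outputs_meta_val = [
    arg.meta["val"] for arg in output_node.args if isinstance(arg, torch.fx.Node)
]
assert len(outputs_meta_val) == 1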
31 changes: 5 additions & 26 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -5,7 +5,6 @@
 
 import tensorrt as trt
 import torch
-from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -18,35 +17,15 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
+from torch_tensorrt.dynamo.utils import (
+    get_model_device,
+    get_output_dtypes,
+    get_torch_inputs,
+)
 
 logger = logging.getLogger(__name__)
 
 
-def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype]:
-    output_dtypes = []
-    if isinstance(output, torch.fx.node.Node):
-        if "val" in output.meta:
-            output_meta = output.meta["val"]
-            if isinstance(output_meta, (FakeTensor, torch.Tensor)):
-                if truncate_double and output_meta.dtype == torch.float64:
-                    output_dtypes.append(dtype.float32)
-                else:
-                    output_dtypes.append(dtype._from(output_meta.dtype))
-        else:
-            raise ValueError(
-                f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] to exist for each output node"
-            )
-    elif isinstance(output, tuple):
-        for ele in output:
-            output_dtypes.extend(get_output_dtypes(ele, truncate_double))
-    else:
-        raise ValueError(
-            f"got unexpected type {type(output)}, expected a torch.fx.node.Node or a tuple of torch.fx.node.Node"
-        )
-    return output_dtypes
-
-
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     truncate_double: bool = False,
14 changes: 13 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -6,6 +6,7 @@
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
+from torch_tensorrt.dynamo.utils import get_output_meta_val, set_output_meta_val
 
 logger = logging.getLogger(__name__)
 
@@ -14,12 +15,23 @@ def lower_linear(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
+
+    outputs = [node for node in gm.graph.nodes if node.op == "output"]
+    outputs = outputs[0].args
+    outputs_meta_val = get_output_meta_val(outputs)
+
     orig, replacement = linear_replacement()
+    replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)
 
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+    if len(replaced_nodes) > 0:
         gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after lowering linear:\n{gm.graph}")
 
+        outputs = [node for node in gm.graph.nodes if node.op == "output"]
+        outputs = outputs[0].args
+        output_num = len(outputs_meta_val)
+        assert output_num > 0
+        set_output_meta_val(outputs, outputs_meta_val)
     return gm
 
 
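
The snapshot/restore around the rewrite exists because torch.fx.subgraph_rewriter.replace_pattern creates fresh nodes that do not, in general, carry over meta["val"] from the nodes they replace, so the pass saves the output meta first and writes it back afterwards; it also keys the cleanup on the returned match list instead of the call's truthiness. A small sketch of that return value, using only stock torch.fx (toy module and pattern are illustrative):

import torch
import torch.fx
from torch.fx import subgraph_rewriter, symbolic_trace


class M(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x) + 1


def pattern(x):
    return torch.neg(x)


def replacement(x):
    return torch.relu(x)


gm = symbolic_trace(M())
replaced_nodes = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
# replace_pattern returns a list of matches, not a bool; the freshly created
# relu node starts with empty meta, hence the save/restore in lower_linear.
assert len(replaced_nodes) == 1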
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -661,3 +661,57 @@ def get_flat_args_with_check(
     flat_args_with_path, received_spec = pytree.tree_flatten_with_path((args, kwargs))
     flat_args = tuple(x[1] for x in flat_args_with_path)
     return flat_args, received_spec
+
+
+def get_output_meta_val(output: Any) -> List[Any]:
+    output_meta_val = []
+    if isinstance(output, torch.fx.node.Node):
+        if "val" in output.meta:
+            output_meta_val.append(output.meta["val"])
+    elif isinstance(output, tuple):
+        for node in output:
+            output_meta_val.extend(get_output_meta_val(node))
+    else:
+        raise ValueError(
+            f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected type: {type(output)=}"
+        )
+    return output_meta_val
+
+
+def set_output_meta_val(output: Any, outputs_meta_val: List[Any]) -> None:
+    if isinstance(output, torch.fx.node.Node):
+        assert len(outputs_meta_val) > 0
+        if "val" not in output.meta:
+            output.meta["val"] = outputs_meta_val[0]
+        outputs_meta_val.pop(0)
+    elif isinstance(output, tuple):
+        for node in output:
+            set_output_meta_val(node, outputs_meta_val)
+    else:
+        raise ValueError(
+            f"expect torch.fx.node.Node or a tuple of torch.fx.node.Node type, got unexpected type: {type(output)=}"
+        )
+
+
+def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype]:
+    output_dtypes = []
+    if isinstance(output, torch.fx.node.Node):
+        if "val" in output.meta:
+            output_meta = output.meta["val"]
+            if isinstance(output_meta, (FakeTensor, torch.Tensor)):
+                if truncate_double and output_meta.dtype == torch.float64:
+                    output_dtypes.append(dtype.float32)
+                else:
+                    output_dtypes.append(dtype._from(output_meta.dtype))
+        else:
+            raise ValueError(
+                f"node.name={output.name}: node.meta['val'] does not exist, expect node.meta['val'] to exist for each output node"
+            )
+    elif isinstance(output, tuple):
+        for ele in output:
+            output_dtypes.extend(get_output_dtypes(ele, truncate_double))
+    else:
+        raise ValueError(
+            f"got unexpected type {type(output)}, expected a torch.fx.node.Node or a tuple of torch.fx.node.Node"
+        )
+    return output_dtypes
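
A short usage sketch of the three helpers added above (assumes a torch_tensorrt build that includes this change; the toy module and tensor shapes are illustrative):

import torch
import torch.fx
from torch_tensorrt.dynamo.utils import (
    get_output_dtypes,
    get_output_meta_val,
    set_output_meta_val,
)


class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2


gm = torch.fx.symbolic_trace(TwoOutputs())
# Simulate the meta["val"] that fake-tensor tracing would attach.
for node in gm.graph.nodes:
    if node.op != "output":
        node.meta["val"] = torch.empty(2, 3)

output_args = [n for n in gm.graph.nodes if n.op == "output"][0].args
saved = get_output_meta_val(output_args)  # flattens the nested output tuple
assert len(saved) == 2
print(get_output_dtypes(output_args))  # two float32 entries

# Wipe the meta and restore it, mirroring what lower_linear does around the
# subgraph rewrite. Note that set_output_meta_val consumes the saved list.
for node in gm.graph.nodes:
    node.meta.pop("val", None)
set_output_meta_val(output_args, saved)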
5 changes: 0 additions & 5 deletions tests/py/dynamo/models/test_models.py
@@ -29,7 +29,6 @@ def test_resnet18(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "ir": "torch_compile",
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
@@ -62,7 +61,6 @@ def test_mobilenet_v2(ir):
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
-        "ir": "torch_compile",
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
@@ -95,7 +93,6 @@ def test_efficientnet_b0(ir):
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
-        "ir": "torch_compile",
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
@@ -137,7 +134,6 @@ def test_bert_base_uncased(ir):
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 15,
-        "ir": "torch_compile",
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
@@ -173,7 +169,6 @@ def test_resnet18_half(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "ir": "torch_compile",
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
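
All five test edits delete the same stray key. A repeated key in a Python dict literal is legal and the last occurrence silently wins, so the hard-coded "ir": "torch_compile" entry was overriding the parametrized ir value in each of these compile specs:

compile_spec = {"ir": "dynamo", "ir": "torch_compile"}
print(compile_spec["ir"])  # prints torch_compile: the later duplicate wins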
