From 53d2188df9930a5110ae286224cfd9867797acd8 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Thu, 14 Mar 2024 16:55:18 -0700
Subject: [PATCH] Update get_aten_graph_module (#121937)

Rename the pt2e helper get_aten_graph_module to
_get_aten_graph_module_for_pattern to mark it as private, update all
call sites, and drop the old name from the docs coverage ignore list
and the public-API allowlist. Also erase the dangling
aten.copy_.default nodes that ep.module() inserts for mutated inputs,
since match/replacement patterns do not need them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121937
Approved by: https://github.com/andrewor14
---
 docs/source/conf.py                           |  1 -
 test/allowlist_for_publicAPI.json             |  1 -
 torch/ao/quantization/pt2e/export_utils.py    | 24 +++++++++++--------
 torch/ao/quantization/pt2e/qat_utils.py       | 12 +++++-----
 .../pt2e/representation/rewrite.py            |  6 ++---
 torch/ao/quantization/pt2e/utils.py           | 22 ++++++++++++-----
 .../quantizer/xnnpack_quantizer_utils.py      |  4 ++--
 7 files changed, 41 insertions(+), 29 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3c85c9c110df4..5b0e763032c22 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -309,7 +309,6 @@
     "reference_representation_rewrite",
     # torch.ao.quantization.pt2e.utils
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
     "remove_tensor_overload_for_qdq_ops",
     # torch.ao.quantization.qconfig
     "get_default_qat_qconfig",
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index 44b64324281f5..9bc60578ea7a2 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -1519,7 +1519,6 @@
     "SharedQuantizationSpec",
     "Tuple",
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
     "replace_pattern_with_filters"
   ],
   "torch.ao.quantization.quantize_fx": [
diff --git a/torch/ao/quantization/pt2e/export_utils.py b/torch/ao/quantization/pt2e/export_utils.py
index d73319df019b1..dae8baad8d28c 100644
--- a/torch/ao/quantization/pt2e/export_utils.py
+++ b/torch/ao/quantization/pt2e/export_utils.py
@@ -46,7 +46,7 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     See https://github.com/pytorch/pytorch/issues/103681.
     """
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern
 
     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -62,17 +62,17 @@ def dropout_eval(x):
 
         example_inputs = (torch.randn(1),)
         if train_to_eval:
-            match_pattern = get_aten_graph_module(
+            match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train), example_inputs
             )
-            replacement_pattern = get_aten_graph_module(
+            replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval), example_inputs
             )
         else:
-            match_pattern = get_aten_graph_module(
+            match_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_eval), example_inputs
             )
-            replacement_pattern = get_aten_graph_module(
+            replacement_pattern = _get_aten_graph_module_for_pattern(
                 _WrapperModule(dropout_train), example_inputs
             )
 
@@ -101,7 +101,7 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     # Enable this support in future updates.
 
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern
 
     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -137,13 +137,17 @@ def bn_eval(
         torch.randn(1),  # bn_running_var
     )
     if train_to_eval:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_train), example_inputs
         )
 
diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index d50a2f608e271..0e7071e5ffae6 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -24,7 +24,7 @@
     _is_conv,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 
 if TYPE_CHECKING:
@@ -546,7 +546,7 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
-    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)
+    match_pattern = _get_aten_graph_module_for_pattern(conv_bn_pattern, example_inputs, is_cuda)
 
     # Step (1): Replace patterns with conv bias
     #
@@ -555,7 +555,7 @@ def _fuse_conv_bn_qat_helper(
 
     # TODO: use the public replace_pattern API once it also returns replacement nodes
     qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
-    replacement_pattern_with_conv_bias = get_aten_graph_module(
+    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
@@ -572,7 +572,7 @@ def _fuse_conv_bn_qat_helper(
 
     # Step (2): Replace patterns without conv bias
     qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
-    replacement_pattern_no_conv_bias = get_aten_graph_module(
+    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
@@ -738,11 +738,11 @@ def _fold_conv_bn_qat_helper(
     match_pattern = _get_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
+    match_pattern = _get_aten_graph_module_for_pattern(match_pattern, example_inputs, is_cuda, **kwargs)
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
+    replacement_pattern = _get_aten_graph_module_for_pattern(replacement_pattern, example_inputs, is_cuda, **kwargs)
     replacements.extend(
         replace_pattern_with_filters(
             m,
diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py
index 36ef2ecbdcdc1..7f5cb2eeb13b8 100644
--- a/torch/ao/quantization/pt2e/representation/rewrite.py
+++ b/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -2,7 +2,7 @@
 from torch.fx import GraphModule
 from ..export_utils import _WrapperModule
 from ..utils import (
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
     remove_tensor_overload_for_qdq_ops,
     _replace_literals_with_new_placeholders,
     _replace_literals_with_existing_placeholders,
@@ -586,9 +586,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
             pattern = pattern_post_trans(pattern)
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index 70e4662be2c54..c80a7072e9069 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -20,7 +20,7 @@
 
 __all__ = [
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
+    "_get_aten_graph_module_for_pattern",
     "remove_tensor_overload_for_qdq_ops",
 ]
 
@@ -292,7 +292,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
             node_name_to_scope[n.name] = current_scope
     return node_name_to_scope
 
-def get_aten_graph_module(
+def _get_aten_graph_module_for_pattern(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     is_cuda: bool = False,
@@ -310,6 +310,16 @@
     )
     aten_pattern.graph.eliminate_dead_code()
     aten_pattern.recompile()
+
+    # ep.module() adds copy_ nodes for the mutated inputs.
+    # For patterns, these copy_ nodes are not needed, so erase the unused ones.
+    for node in aten_pattern.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default and len(node.users) == 0:
+            aten_pattern.graph.erase_node(node)
+
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+
     return aten_pattern
 
 def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
@@ -370,8 +380,8 @@ def replacement(self, x):
         return x - 3
 
     example_inputs = (torch.randn(1, 3, 3, 3),)
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
 
     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x):
@@ -456,8 +466,8 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max):
         -128,
         127,
     )
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
 
     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index 042163705a0b9..d08fa49c2f4ee 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -12,7 +12,7 @@
 from torch.ao.quantization.pt2e.utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -469,7 +469,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
         pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
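
Note for reviewers: below is a minimal sketch of exercising the renamed helper after this patch, mirroring the dropout pattern from export_utils.py above. The single-element input shape and the standalone script framing are illustrative only, not part of the change; the imports rely on the private helpers this patch defines.

import torch
import torch.nn.functional as F

from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern

def dropout_train(x):
    # Same pattern body used by _replace_dropout above
    return F.dropout(x, p=0.5, training=True)

example_inputs = (torch.randn(1),)

# Export the callable to an ATen-level GraphModule suitable for
# subgraph matching.
pattern_gm = _get_aten_graph_module_for_pattern(
    _WrapperModule(dropout_train), example_inputs
)

# After this patch, any dangling aten.copy_.default nodes that
# ep.module() inserted for mutated inputs have been erased from the
# pattern graph, so none should remain unused here.
assert not any(
    node.op == "call_function"
    and node.target == torch.ops.aten.copy_.default
    and len(node.users) == 0
    for node in pattern_gm.graph.nodes
)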