Update get_aten_graph_module (pytorch#121937)
Pull Request resolved: pytorch#121937
Approved by: https://github.com/andrewor14
tugsbayasgalan authored and pytorchmergebot committed Mar 15, 2024
1 parent af86d67 commit 53d2188
Showing 7 changed files with 41 additions and 29 deletions.
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -309,7 +309,6 @@
"reference_representation_rewrite",
# torch.ao.quantization.pt2e.utils
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"remove_tensor_overload_for_qdq_ops",
# torch.ao.quantization.qconfig
"get_default_qat_qconfig",
1 change: 0 additions & 1 deletion test/allowlist_for_publicAPI.json
@@ -1519,7 +1519,6 @@
"SharedQuantizationSpec",
"Tuple",
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"replace_pattern_with_filters"
],
"torch.ao.quantization.quantize_fx": [
24 changes: 14 additions & 10 deletions torch/ao/quantization/pt2e/export_utils.py
@@ -46,7 +46,7 @@ def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     See https://github.com/pytorch/pytorch/issues/103681.
     """
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern
 
     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -62,17 +62,17 @@ def dropout_eval(x):

     example_inputs = (torch.randn(1),)
     if train_to_eval:
-        match_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train), example_inputs
         )
-        replacement_pattern = get_aten_graph_module(
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_eval), example_inputs
         )
-        replacement_pattern = get_aten_graph_module(
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(dropout_train), example_inputs
         )

@@ -101,7 +101,7 @@ def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
     # Enable this support in future updates.
 
     # Avoid circular dependencies
-    from .utils import get_aten_graph_module
+    from .utils import _get_aten_graph_module_for_pattern
 
     # Needed to ensure subgraph matches are self-contained
     m.graph.eliminate_dead_code()
@@ -137,13 +137,17 @@ def bn_eval(
         torch.randn(1),  # bn_running_var
     )
     if train_to_eval:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_train), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_eval), example_inputs
         )
     else:
-        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
-        replacement_pattern = get_aten_graph_module(
+        match_pattern = _get_aten_graph_module_for_pattern(
+            _WrapperModule(bn_eval), example_inputs
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
             _WrapperModule(bn_train), example_inputs
         )

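For orientation, here is a minimal sketch (not part of the diff) of how a match/replacement pair like the dropout one above is built and consumed. It assumes only the private helpers shown in this commit plus the public FX subgraph rewriter; whether tracing succeeds depends on the torch.export behavior of this era.

import torch
import torch.nn.functional as F
from torch.fx import subgraph_rewriter

from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern

def dropout_train(x):
    return F.dropout(x, p=0.5, training=True)

def dropout_eval(x):
    return F.dropout(x, p=0.5, training=False)

example_inputs = (torch.randn(1),)
# Trace both variants into ATen-level graph modules so they match exported-model IR.
match_pattern = _get_aten_graph_module_for_pattern(
    _WrapperModule(dropout_train), example_inputs
)
replacement_pattern = _get_aten_graph_module_for_pattern(
    _WrapperModule(dropout_eval), example_inputs
)
# m would be an exported torch.fx.GraphModule; the rewrite itself is then:
# subgraph_rewriter.replace_pattern_with_filters(
#     m, match_pattern, replacement_pattern, match_filters=[], ignore_literals=True
# )

The same mechanics apply to _replace_batchnorm above; only the traced callables differ.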
12 changes: 6 additions & 6 deletions torch/ao/quantization/pt2e/qat_utils.py
@@ -24,7 +24,7 @@
     _is_conv,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 
 if TYPE_CHECKING:
@@ -546,7 +546,7 @@ def _fuse_conv_bn_qat_helper(
     m.graph.eliminate_dead_code()
     m.recompile()
     conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
-    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)
+    match_pattern = _get_aten_graph_module_for_pattern(conv_bn_pattern, example_inputs, is_cuda)
 
     # Step (1): Replace patterns with conv bias
     #
@@ -555,7 +555,7 @@
     # TODO: use the public replace_pattern API once it also returns replacement nodes
 
     qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
-    replacement_pattern_with_conv_bias = get_aten_graph_module(
+    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern,
         example_inputs,
         is_cuda,
@@ -572,7 +572,7 @@
     # Step (2): Replace patterns without conv bias
 
     qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
-    replacement_pattern_no_conv_bias = get_aten_graph_module(
+    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
         qat_conv_bn_pattern_no_conv_bias,
         example_inputs,
         is_cuda,
@@ -738,11 +738,11 @@ def _fold_conv_bn_qat_helper(
     match_pattern = _get_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
+    match_pattern = _get_aten_graph_module_for_pattern(match_pattern, example_inputs, is_cuda, **kwargs)
     replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
         is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
     )
-    replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
+    replacement_pattern = _get_aten_graph_module_for_pattern(replacement_pattern, example_inputs, is_cuda, **kwargs)
     replacements.extend(
         replace_pattern_with_filters(
             m,
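The qat_utils changes are mechanical renames, but the surrounding flow is worth a sketch: both the match and the replacement callables are lowered to ATen graph modules before being handed to the FX rewriter, whose return value exposes the freshly inserted nodes (which is why the TODO above wants the public replace_pattern API to do the same). A hedged sketch with a simplified pattern body; conv_bn here is illustrative, not the helper from the file:

import torch
import torch.nn.functional as F
from torch.fx import subgraph_rewriter

from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
)

def conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
    x = F.conv2d(x, conv_weight, conv_bias)
    return F.batch_norm(x, bn_rm, bn_rv, bn_weight, bn_bias, training=True)

match_pattern = _get_aten_graph_module_for_pattern(
    _WrapperModule(conv_bn), _conv2d_bn_example_inputs
)
# Given a replacement_pattern traced the same way against a GraphModule m,
# the rewriter reports each match along with the nodes it inserted:
# for r in subgraph_rewriter.replace_pattern_with_filters(
#     m, match_pattern, replacement_pattern, ignore_literals=True
# ):
#     new_nodes = r.replacements  # used downstream to re-annotate the fused graph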
6 changes: 3 additions & 3 deletions torch/ao/quantization/pt2e/representation/rewrite.py
@@ -2,7 +2,7 @@
 from torch.fx import GraphModule
 from ..export_utils import _WrapperModule
 from ..utils import (
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
     remove_tensor_overload_for_qdq_ops,
     _replace_literals_with_new_placeholders,
     _replace_literals_with_existing_placeholders,
@@ -586,9 +586,9 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         replacement = rewrite_info.replacement
         pattern_post_trans = rewrite_info.pattern_post_trans
         replacement_post_trans = rewrite_info.replacement_post_trans
-        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
-        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
+        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
         remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
         if pattern_post_trans:
            pattern = pattern_post_trans(pattern)
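reference_representation_rewrite itself is usually reached indirectly. A hedged usage sketch, assuming the public pt2e API of this commit's timeframe and an already-captured exported_model with its example_inputs:

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# exported_model: a GraphModule captured from the pt2e export entry point
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported_model, quantizer)
prepared(*example_inputs)  # calibration pass
# use_reference_representation=True routes through reference_representation_rewrite,
# applying each _RewriteInfo pattern/replacement pair built above.
converted = convert_pt2e(prepared, use_reference_representation=True)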
22 changes: 16 additions & 6 deletions torch/ao/quantization/pt2e/utils.py
@@ -20,7 +20,7 @@

 __all__ = [
     "fold_bn_weights_into_conv_node",
-    "get_aten_graph_module",
+    "_get_aten_graph_module_for_pattern",
     "remove_tensor_overload_for_qdq_ops",
 ]

@@ -292,7 +292,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
         node_name_to_scope[n.name] = current_scope
     return node_name_to_scope
 
-def get_aten_graph_module(
+def _get_aten_graph_module_for_pattern(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     is_cuda: bool = False,
@@ -310,6 +310,16 @@ def get_aten_graph_module(
     )
     aten_pattern.graph.eliminate_dead_code()
     aten_pattern.recompile()
+
+    # ep.module() adds copy_ nodes for the mutated inputs;
+    # for patterns, these dangling copy_ nodes don't matter.
+    for node in aten_pattern.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default and len(node.users) == 0:
+            aten_pattern.graph.erase_node(node)
+
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+
     return aten_pattern
 
 def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
@@ -370,8 +380,8 @@ def replacement(self, x):
         return x - 3
 
     example_inputs = (torch.randn(1, 3, 3, 3),)
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
 
     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x):
@@ -456,8 +466,8 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max):
         -128,
         127,
     )
-    pattern_gm = get_aten_graph_module(pattern, example_inputs)
-    replacement_gm = get_aten_graph_module(pattern, example_inptus)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
 
     # 2. Before calling replace literals we'll see the following graph:
     def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
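The functional change in this file is the new cleanup pass. A minimal repro sketch of the behavior it compensates for, assuming torch.export semantics as of this commit: exporting a pattern whose op mutates an input, such as train-mode batch norm updating its running stats, makes ep.module() re-insert aten.copy_.default nodes to preserve the input mutation, and those user-less nodes would otherwise pollute pattern matching.

import torch

class BNPattern(torch.nn.Module):
    def forward(self, x, bn_rm, bn_rv):
        # training=True mutates bn_rm/bn_rv in place
        return torch.nn.functional.batch_norm(x, bn_rm, bn_rv, training=True)

example_inputs = (torch.randn(1, 1, 3, 3), torch.randn(1), torch.randn(1))
ep = torch.export.export(BNPattern(), example_inputs)
gm = ep.module()  # may contain aten.copy_.default nodes with no users

# The same cleanup the renamed helper now applies:
for node in list(gm.graph.nodes):
    if (
        node.op == "call_function"
        and node.target == torch.ops.aten.copy_.default
        and len(node.users) == 0
    ):
        gm.graph.erase_node(node)
gm.graph.eliminate_dead_code()
gm.recompile()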
4 changes: 2 additions & 2 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -12,7 +12,7 @@
 from torch.ao.quantization.pt2e.utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    get_aten_graph_module,
+    _get_aten_graph_module_for_pattern,
 )
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -469,7 +469,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
     # Match against all conv dimensions and cuda variants
     for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
         pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
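Finally, the XNNPACK annotator consumes these pattern modules through SubgraphMatcherWithNameNodeMap, which works because the pattern function returns its named internal nodes alongside the output. A hedged sketch mirroring the _conv_bn helper in the hunk above; the dict keys are illustrative, not the exact ones in the file:

import torch
import torch.nn.functional as F
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
)

def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
    conv = F.conv2d(x, conv_weight, conv_bias)
    bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
    # The second return value names the nodes the annotator needs to find.
    return bn, {"input": x, "conv": conv, "bn": bn, "output": bn}

pattern = _get_aten_graph_module_for_pattern(
    _WrapperModule(_conv_bn), _conv2d_bn_example_inputs
)
pattern.graph.eliminate_dead_code()
pattern.recompile()
matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
# Each match carries a name_node_map, so annotation code can look up nodes by
# role in the target graph, e.g.:
# for match in matcher.match(model.graph):
#     conv_node = match.name_node_map["conv"]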
