From 0a1b3be2163ea99633f95c4927bd816eb713e9bd Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 18 Mar 2024 19:19:26 +0000 Subject: [PATCH] chore: add unit test to verify split_by_tags output_type (#121262) Add a test case as per https://github.com/pytorch/pytorch/pull/120361#issuecomment-1979163324 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121262 Approved by: https://github.com/atalman --- test/fx/test_fx_split.py | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index ff776fe1adcde..fa8910ed61e2b 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -150,3 +150,76 @@ def test_split_by_tags(self) -> None: }, f"{orig_to_split_fqn_mapping=}", ) + +class TestSplitOutputType(TestCase): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + conv = conv * 0.5 + relu = self.relu(conv) + return relu + + @staticmethod + def trace_and_tag( + module: torch.nn.Module, inputs: torch.Tensor, tags: List[str] + ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]: + """ + Test simple gm consists of nodes with tag (only show call_module nodes here): + conv - tag: "red" + mul - tag: "blue" + relu - tag: "green" + + At the beginning we have: + gm: + conv + mul + relu + + split_gm = split_by_tags(gm, tags) + + Then we have: + split_gm: + red: + conv + blue: + mul + green: + relu + """ + tag_node = defaultdict(list) + gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module() + # Add tag to all nodes and build dictionary record tag to call_module nodes + for node in gm.graph.nodes: + if "conv" in node.name: + node.tag = tags[0] + tag_node[tags[0]].append(node.name) + elif "mul" in node.name: + node.tag = tags[1] + tag_node[tags[1]].append(node.name) + else: + node.tag = tags[2] + if node.op == "call_module": + tag_node[tags[2]].append(node.name) + return gm, tag_node + + def test_split_by_tags(self) -> None: + tags = ["red", "blue", "green"] + module = TestSplitOutputType.TestModule() + + inputs = torch.randn((1, 3, 224, 224)) + + gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags) + split_gm, orig_to_split_fqn_mapping = split_by_tags( + gm, tags, return_fqn_mapping=True + ) + + gm_output = module(inputs) + split_gm_output = split_gm(inputs) + + self.assertTrue(type(gm_output) == type(split_gm_output)) + self.assertTrue(torch.equal(gm_output, split_gm_output))