diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index ff776fe1adcde..fa8910ed61e2b 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -150,3 +150,76 @@ def test_split_by_tags(self) -> None: }, f"{orig_to_split_fqn_mapping=}", ) + +class TestSplitOutputType(TestCase): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + conv = conv * 0.5 + relu = self.relu(conv) + return relu + + @staticmethod + def trace_and_tag( + module: torch.nn.Module, inputs: torch.Tensor, tags: List[str] + ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]: + """ + Test simple gm consists of nodes with tag (only show call_module nodes here): + conv - tag: "red" + mul - tag: "blue" + relu - tag: "green" + + At the beginning we have: + gm: + conv + mul + relu + + split_gm = split_by_tags(gm, tags) + + Then we have: + split_gm: + red: + conv + blue: + mul + green: + relu + """ + tag_node = defaultdict(list) + gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module() + # Add tag to all nodes and build dictionary record tag to call_module nodes + for node in gm.graph.nodes: + if "conv" in node.name: + node.tag = tags[0] + tag_node[tags[0]].append(node.name) + elif "mul" in node.name: + node.tag = tags[1] + tag_node[tags[1]].append(node.name) + else: + node.tag = tags[2] + if node.op == "call_module": + tag_node[tags[2]].append(node.name) + return gm, tag_node + + def test_split_by_tags(self) -> None: + tags = ["red", "blue", "green"] + module = TestSplitOutputType.TestModule() + + inputs = torch.randn((1, 3, 224, 224)) + + gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags) + split_gm, orig_to_split_fqn_mapping = split_by_tags( + gm, tags, return_fqn_mapping=True + ) + + gm_output = module(inputs) + split_gm_output = split_gm(inputs) + + self.assertTrue(type(gm_output) == type(split_gm_output)) + self.assertTrue(torch.equal(gm_output, split_gm_output))