Skip to content

Commit

Permalink
chore: add unit test to verify split_by_tags output_type (pytorch#121262
Browse files Browse the repository at this point in the history
)

Add a test case as per pytorch#120361 (comment)

Pull Request resolved: pytorch#121262
Approved by: https://github.com/atalman
  • Loading branch information
peri044 authored and pytorchmergebot committed Mar 18, 2024
1 parent 676a771 commit 0a1b3be
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions test/fx/test_fx_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,76 @@ def test_split_by_tags(self) -> None:
},
f"{orig_to_split_fqn_mapping=}",
)

class TestSplitOutputType(TestCase):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
conv = self.conv(x)
conv = conv * 0.5
relu = self.relu(conv)
return relu

@staticmethod
def trace_and_tag(
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
"""
Test simple gm consists of nodes with tag (only show call_module nodes here):
conv - tag: "red"
mul - tag: "blue"
relu - tag: "green"
At the beginning we have:
gm:
conv
mul
relu
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
conv
blue:
mul
green:
relu
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
# Add tag to all nodes and build dictionary record tag to call_module nodes
for node in gm.graph.nodes:
if "conv" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "mul" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node

def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitOutputType.TestModule()

inputs = torch.randn((1, 3, 224, 224))

gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
split_gm, orig_to_split_fqn_mapping = split_by_tags(
gm, tags, return_fqn_mapping=True
)

gm_output = module(inputs)
split_gm_output = split_gm(inputs)

self.assertTrue(type(gm_output) == type(split_gm_output))
self.assertTrue(torch.equal(gm_output, split_gm_output))

0 comments on commit 0a1b3be

Please sign in to comment.