From 9584d614a1cd4c32c8dbe4487b2d2e81c3462d6a Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Thu, 20 Jul 2023 19:12:17 +0000 Subject: [PATCH] [inductor] add decompositions for aten.angle (#105609) Fixes #105564. Added tests. CPU benchmarking result: Before decomposition: ``` [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] ===== Forward graph 0 ===== [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] .4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class (torch.nn.Module): [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, arg0_1: f32[100000]): [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/yidi/local/t.py:5, code: return torch.angle(x) [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] angle: f32[100000] = torch.ops.aten.angle.default(arg0_1); arg0_1 = None [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] return (angle,) [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] eager: per-call time (us): 1069.2930221557617 compiled: per-call time (us): 742.4068450927734 ``` After decomposition: ``` [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] ===== Forward graph 0 ===== [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] .4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class (torch.nn.Module): [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] def forward(self, arg0_1: f32[100000]): [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] # File: /home/yidi/local/t.py:5, code: return torch.angle(x) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] lt: b8[100000] = torch.ops.aten.lt.Scalar(arg0_1, 0) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(3.141592653589793, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] where: f32[100000] = torch.ops.aten.where.self(lt, scalar_tensor_1, scalar_tensor); lt = scalar_tensor_1 = scalar_tensor = None [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] isnan: b8[100000] = torch.ops.aten.isnan.default(arg0_1); arg0_1 = None [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(nan, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] where_1: f32[100000] = torch.ops.aten.where.self(isnan, scalar_tensor_3, scalar_tensor_2); isnan = scalar_tensor_3 = scalar_tensor_2 = None [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] add: f32[100000] = torch.ops.aten.add.Tensor(where, where_1); where = where_1 = None [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] return (add,) [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] [2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] eager: per-call time (us): 1228.0082702636719 compiled: per-call time (us): 83.6038589477539 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/105609 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 11 +++++++++++ torch/_inductor/decomposition.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 051bc70207cf0..47c9eab5ba70b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -581,6 +581,17 @@ def fn(a): self.common(fn, (torch.randn(17),)) + def test_angle(self): + def fn(a, b, c): + return torch.angle(a), torch.angle(b), torch.angle(c) + + complex_input = torch.tensor( + [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1, float("nan")] + ) + real_input = torch.tensor([-1.0, 0.0, 1.0, float("nan")]) + interger_real_input = torch.tensor([-1, 0, 1]) + self.common(fn, (complex_input, real_input, interger_real_input)) + def test_sgn(self): def fn(a): return torch.sgn(a), torch.sgn(a + 1) - 1 diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 93d90a52a896b..fcb80f303baf2 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -169,6 +169,22 @@ def cat(tensors, dim=0): return NotImplemented +@register_decomposition([aten.angle]) +def angle(x): + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + else: + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + ret = torch.where(x < 0, math.pi, 0.0) + nan = torch.where(torch.isnan(x), float("nan"), 0.0) + return ret + nan + + @register_decomposition([aten.conj_physical]) def conj_physical(self): assert not self.is_complex(), "TODO: implement this"