Skip to content

Commit

Permalink
[inductor] add decompositions for aten.angle (pytorch#105609)
Browse files Browse the repository at this point in the history
Fixes pytorch#105564.

Added tests.

CPU benchmarking result:
Before decomposition:
```
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         angle: f32[100000] = torch.ops.aten.angle.default(arg0_1);  arg0_1 = None
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (angle,)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]
eager:
per-call time (us): 1069.2930221557617
compiled:
per-call time (us): 742.4068450927734
```

After decomposition:
```
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         lt: b8[100000] = torch.ops.aten.lt.Scalar(arg0_1, 0)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(3.141592653589793, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where: f32[100000] = torch.ops.aten.where.self(lt, scalar_tensor_1, scalar_tensor);  lt = scalar_tensor_1 = scalar_tensor = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         isnan: b8[100000] = torch.ops.aten.isnan.default(arg0_1);  arg0_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(nan, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where_1: f32[100000] = torch.ops.aten.where.self(isnan, scalar_tensor_3, scalar_tensor_2);  isnan = scalar_tensor_3 = scalar_tensor_2 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         add: f32[100000] = torch.ops.aten.add.Tensor(where, where_1);  where = where_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (add,)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]
eager:
per-call time (us): 1228.0082702636719
compiled:
per-call time (us): 83.6038589477539
```

Pull Request resolved: pytorch#105609
Approved by: https://github.com/jansel
  • Loading branch information
ydwu4 authored and pytorchmergebot committed Jul 20, 2023
1 parent 9760ea5 commit 9584d61
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,17 @@ def fn(a):

self.common(fn, (torch.randn(17),))

def test_angle(self):
def fn(a, b, c):
return torch.angle(a), torch.angle(b), torch.angle(c)

complex_input = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1, float("nan")]
)
real_input = torch.tensor([-1.0, 0.0, 1.0, float("nan")])
interger_real_input = torch.tensor([-1, 0, 1])
self.common(fn, (complex_input, real_input, interger_real_input))

def test_sgn(self):
def fn(a):
return torch.sgn(a), torch.sgn(a + 1) - 1
Expand Down
16 changes: 16 additions & 0 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,22 @@ def cat(tensors, dim=0):
return NotImplemented


@register_decomposition([aten.angle])
def angle(x):
if x.is_complex():
return torch.where(
torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
)
else:
# when x is real number
# if x >= 0, return 0
# if x < 0, return pi
# if x is nan, return nan
ret = torch.where(x < 0, math.pi, 0.0)
nan = torch.where(torch.isnan(x), float("nan"), 0.0)
return ret + nan


@register_decomposition([aten.conj_physical])
def conj_physical(self):
assert not self.is_complex(), "TODO: implement this"
Expand Down

0 comments on commit 9584d61

Please sign in to comment.