Commit
[inductor] add decompositions for aten.angle (pytorch#105609)
Fixes pytorch#105564.

Added tests.

CPU benchmarking result:

Before decomposition:
```
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         angle: f32[100000] = torch.ops.aten.angle.default(arg0_1);  arg0_1 = None
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (angle,)
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-19 14:59:51,277] torch._functorch.aot_autograd.__aot_graphs: [INFO]
eager: per-call time (us): 1069.2930221557617
compiled: per-call time (us): 742.4068450927734
```

After decomposition:
```
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.4 from /home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py:477 in wrapped class <lambda>(torch.nn.Module):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, arg0_1: f32[100000]):
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/yidi/local/t.py:5, code: return torch.angle(x)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         lt: b8[100000] = torch.ops.aten.lt.Scalar(arg0_1, 0)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(3.141592653589793, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where: f32[100000] = torch.ops.aten.where.self(lt, scalar_tensor_1, scalar_tensor);  lt = scalar_tensor_1 = scalar_tensor = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         isnan: b8[100000] = torch.ops.aten.isnan.default(arg0_1);  arg0_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(nan, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         where_1: f32[100000] = torch.ops.aten.where.self(isnan, scalar_tensor_3, scalar_tensor_2);  isnan = scalar_tensor_3 = scalar_tensor_2 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         add: f32[100000] = torch.ops.aten.add.Tensor(where, where_1);  where = where_1 = None
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return (add,)
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-07-19 14:57:53,849] torch._functorch.aot_autograd.__aot_graphs: [INFO]
eager: per-call time (us): 1228.0082702636719
compiled: per-call time (us): 83.6038589477539
```

Pull Request resolved: pytorch#105609
Approved by: https://github.com/jansel
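
For readers who want to follow the "after" graph without running PyTorch, here is a minimal pure-Python sketch of what its ops compute for a single real float element. It is an illustration under stated assumptions, not the decomposition code added by this commit: `angle_real` is a hypothetical helper name, and it only mirrors the `lt`, `where`, `isnan`, and `add` nodes shown in the traced graph for an `f32` input.

```python
import math


def angle_real(x: float) -> float:
    """Hypothetical per-element mirror of the traced graph for a real input."""
    # where(lt(x, 0), pi, 0.0): negative reals have angle pi, others 0
    base = math.pi if x < 0 else 0.0
    # where(isnan(x), nan, 0): a NaN input contributes a NaN term
    nan_term = math.nan if math.isnan(x) else 0.0
    # add(where, where_1): NaN + anything is NaN, so NaN inputs propagate
    return base + nan_term


values = [-2.0, 0.0, 3.5, math.nan]
results = [angle_real(v) for v in values]
```

Adding the NaN term works because a comparison with NaN is false, so the first `where` alone would map NaN to 0; the second `where` restores NaN propagation through the add.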