Skip to content

Commit

Permalink
Implement ATen complex and polar | feat(torchlib) (#1286)
Browse files Browse the repository at this point in the history
aten::complex has a broadcasting behavior which is implemented here.

**NOTE:** Optimizations should consider eliminating the `Expand` node
when the broadcasted shape is the same as the input shape.

Fixes pytorch/pytorch#121100
  • Loading branch information
justinchuby authored Mar 4, 2024
1 parent bbb9584 commit aded324
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
24 changes: 20 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,10 +1641,23 @@ def aten_combinations(
raise NotImplementedError()


def aten_complex(real: TensorType, imag: TensorType) -> TensorType:
@torch_op("aten::complex", private=True)
def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""Non-broadcasting complex constructor."""

return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""

raise NotImplementedError()
# Broadcast the real and imaginary parts to the same shape
broadcasted_shape = _shape_of_broadcast_tensors(real, imag)
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)

return _aten_complex(real, imag)


@torch_op("aten::concat")
Expand Down Expand Up @@ -6385,10 +6398,13 @@ def aten_poisson_nll_loss(
raise NotImplementedError()


def aten_polar(abs: TensorType, angle: TensorType) -> TensorType:
@torch_op("aten::polar")
def aten_polar(abs: TFloat, angle: TFloat) -> TFloat:
"""polar(Tensor abs, Tensor angle) -> Tensor"""

raise NotImplementedError()
real = op.Unsqueeze(op.Mul(abs, op.Cos(angle)), axes=[-1])
imag = op.Unsqueeze(op.Mul(abs, op.Sin(angle)), axes=[-1])
return op.Concat(real, imag, axis=-1)


def aten_polygamma(n: int, self: TensorType) -> TensorType:
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ def _where_input_wrangler(
reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
),
TorchLibOpInfo("clone", core_ops.aten_clone),
TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True),
TorchLibOpInfo("concat", core_ops.aten_concat).skip(
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
Expand Down Expand Up @@ -1332,6 +1333,7 @@ def _where_input_wrangler(
input_wrangler=_permute_input_wrangler,
trace_only=True,
),
TorchLibOpInfo("polar", core_ops.aten_polar),
TorchLibOpInfo("pow", core_ops.aten_pow),
TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
Expand Down

0 comments on commit aded324

Please sign in to comment.