Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ATen complex and polar | feat(torchlib) #1286

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,10 +1641,23 @@ def aten_combinations(
raise NotImplementedError()


def aten_complex(real: TensorType, imag: TensorType) -> TensorType:
@torch_op("aten::complex", private=True)
def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""Non-broadcasting complex constructor."""

return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""

raise NotImplementedError()
# Broadcast the real and imaginary parts to the same shape
broadcasted_shape = _shape_of_broadcast_tensors(real, imag)
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)

return _aten_complex(real, imag)


@torch_op("aten::concat")
Expand Down Expand Up @@ -6385,10 +6398,13 @@ def aten_poisson_nll_loss(
raise NotImplementedError()


def aten_polar(abs: TensorType, angle: TensorType) -> TensorType:
@torch_op("aten::polar")
def aten_polar(abs: TFloat, angle: TFloat) -> TFloat:
"""polar(Tensor abs, Tensor angle) -> Tensor"""

raise NotImplementedError()
real = op.Unsqueeze(op.Mul(abs, op.Cos(angle)), axes=[-1])
imag = op.Unsqueeze(op.Mul(abs, op.Sin(angle)), axes=[-1])
return op.Concat(real, imag, axis=-1)


def aten_polygamma(n: int, self: TensorType) -> TensorType:
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ def _where_input_wrangler(
reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492",
),
TorchLibOpInfo("clone", core_ops.aten_clone),
TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True),
TorchLibOpInfo("concat", core_ops.aten_concat).skip(
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
Expand Down Expand Up @@ -1332,6 +1333,7 @@ def _where_input_wrangler(
input_wrangler=_permute_input_wrangler,
trace_only=True,
),
TorchLibOpInfo("polar", core_ops.aten_polar),
TorchLibOpInfo("pow", core_ops.aten_pow),
TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
Expand Down
Loading