Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unique op #1547

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
77 changes: 70 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8380,14 +8380,27 @@
) -> tuple[TensorType, TensorType]:
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
unique_values, indices, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
# HACK: force indices to be in the graph so that it gets a name during optimization
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest removing all hacks. I will go fix what's necessary where the bug is. We are also moving to prefer trace_only=True for new functions so if you can include the flag in @torch_op that would be awesome.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be awesome. The hacks are definitely getting out of hand. I'll wait for that fix so that I can continue to test with this locally.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a short script handy that will reproduce the error?

Copy link
Author

@a-gardner1 a-gardner1 May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if __name__ == '__main__':
    import logging
    import torch
    import numpy as np
    import onnx
    import onnxruntime as ort
    for i in range(16):
        sorted = bool(i & 1)
        return_inverse = bool((i & 2) > 1)
        return_counts = bool((i & 4) > 1)
        dim = 0 if bool((i & 8) > 1) else None

        print(
            f"Testing sorted={sorted}, return_inverse={return_inverse}, return_counts={return_counts}, dim={dim}"
        )

        def test_function(
                x: torch.Tensor,
                s: bool = sorted,
                ri: bool = return_inverse,
                rc: bool = return_counts,
                d: int | None = dim) -> Any:
            result = torch.unique(
                x,
                sorted=s,
                return_inverse=ri,
                return_counts=rc,
                dim=d)
            return result

        onnx_program = torch.onnx.dynamo_export(
            test_function,
            torch.arange(10),
            export_options=torch.onnx.ExportOptions(
                dynamic_shapes=True,
                diagnostic_options=torch.onnx.DiagnosticOptions(
                    verbosity_level=logging.DEBUG)))
        onnx_program.save("torch_unique.onnx")
        onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
        onnx_outputs = onnx_program(*onnx_inputs)
        loaded_onnx_program = onnx.load("torch_unique.onnx")
        onnx.checker.check_model(loaded_onnx_program)
        ort_session = ort.InferenceSession("torch_unique.onnx")
        inputs = np.random.randint(0, 10, 10)
        print(f"Inputs: {inputs}")
        outputs = ort_session.run(None,
                                  {"l_x_": inputs})
        print(f"Outputs: {outputs}")
    print("Success")

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you should also test using the nightly release of PyTorch with the changes in pytorch/pytorch#126561.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is trace_only=True expected to require significant changes to the way one implements an op? It appears that enabling the flag breaks passing a value to op.ConstantOfShape and also breaks indexing a shape.

For example, op.ConstantOfShape([0], value=[0]) must become op.Cast(op.ConstantOfShape([0]), to=INT64.dtype), and output_size[dim] must become op.Slice(output_size, [dim], [dim+1]).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your observation is correct. This may be the case because the gaps in implementation we have. Bridging the gaps is in our roadmap but is not the highest priority for the team.

# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
# We don't need to worry about unique_values since it is a required output.
indices_size = op.Shape(indices)
indices_numel = op.ReduceProd(indices_size, keepdims=False)
inverse_indices_size = op.Shape(inverse_indices)
inverse_indices_numel = op.ReduceProd(inverse_indices_size, keepdims=False)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
input_size = op.Shape(self)
# force inverse_indices to depend on indices through input_size
if indices_numel != 0:
input_size = input_size * indices_numel
input_size = input_size / indices_numel
else:
input_size = input_size + indices_numel

Check warning on line 8397 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8397

Added line #L8397 was not covered by tests
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)

Check warning on line 8403 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8403

Added line #L8403 was not covered by tests
else:
inverse_indices = op.ConstantOfShape([0], value=[0])
return unique_values, inverse_indices
Expand All @@ -8403,22 +8416,46 @@
"""_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
# HACK: force indices and inverse_indices to be in the graph so
# that they get names during optimization.
# counts must depend on indices and inverse_indices,
# and inverse_indices must depend on indices
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
# We don't have to worry about unique_values because it is a required output.
indices_size = op.Shape(indices)
indices_numel = op.ReduceProd(indices_size, keepdims=False)
inverse_indices_size = op.Shape(inverse_indices)
inverse_indices_numel = op.ReduceProd(inverse_indices_size, keepdims=False)
input_size = op.Shape(self)
# force inverse_indices to depend on indices through input_size
if indices_numel != 0:
input_size = input_size * indices_numel
input_size = input_size / indices_numel
else:
input_size = input_size + indices_numel

Check warning on line 8435 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8435

Added line #L8435 was not covered by tests
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)

Check warning on line 8441 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8441

Added line #L8441 was not covered by tests
else:
inverse_indices = op.ConstantOfShape([0], value=[0])
if return_counts:
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
# force counts to depend on inverse_indices through indices_size
if inverse_indices_numel != 0:
indices_size = indices_size * inverse_indices_numel
indices_size = indices_size / inverse_indices_numel
else:
indices_size = indices_size + inverse_indices_numel

Check warning on line 8450 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8450

Added line #L8450 was not covered by tests
# force counts to depend on indices
counts = op.Reshape(counts, indices_size)
else:
counts = op.ConstantOfShape([0], value=[0])
# force counts to depend on indices
counts = counts * indices_numel
# force counts to depend on inverse_indices
counts = counts * inverse_indices_numel
return unique_values, inverse_indices, counts


Expand All @@ -8429,26 +8466,52 @@
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
is_cuda: bool = False
Fixed Show fixed Hide fixed
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, indices, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)

Check warning on line 8473 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8473

Added line #L8473 was not covered by tests
# HACK: force indices and inverse_indices to be in the graph so
# that they get names during optimization.
# counts must depend on indices and inverse_indices,
# and inverse_indices must depend on indices
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
# We don't have to worry about unique_values because it is a required output.
indices_size = op.Shape(indices)
indices_numel = op.ReduceProd(indices_size, keepdims=False)
inverse_indices_size = op.Shape(inverse_indices)
inverse_indices_numel = op.ReduceProd(inverse_indices_size, keepdims=False)

Check warning on line 8483 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8480-L8483

Added lines #L8480 - L8483 were not covered by tests
if return_inverse:
input_size = op.Shape(self)

Check warning on line 8485 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8485

Added line #L8485 was not covered by tests
# force inverse_indices to depend on indices through input_size
if indices_numel != 0:
input_size = input_size * indices_numel
input_size = input_size / indices_numel

Check warning on line 8489 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8488-L8489

Added lines #L8488 - L8489 were not covered by tests
else:
input_size = input_size + indices_numel
inverse_indices = op.Reshape(inverse_indices, op.Reshape(input_size[dim], [-1]))

Check warning on line 8492 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8491-L8492

Added lines #L8491 - L8492 were not covered by tests
else:
inverse_indices = op.ConstantOfShape([0], value=[0])

Check warning on line 8494 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8494

Added line #L8494 was not covered by tests
# force inverse_indices to depend on indices
inverse_indices = inverse_indices * indices_numel

Check warning on line 8496 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8496

Added line #L8496 was not covered by tests
if return_counts:
# HACK: force indices to be in the graph so that it gets a name during optimization
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
indices_size = op.Shape(indices)
# force dependence on inverse_indices through indices_size
if inverse_indices_numel != 0:
indices_size = indices_size * inverse_indices_numel
indices_size = indices_size / inverse_indices_numel

Check warning on line 8501 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8500-L8501

Added lines #L8500 - L8501 were not covered by tests
else:
indices_size = indices_size + inverse_indices_numel

Check warning on line 8503 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8503

Added line #L8503 was not covered by tests
# force dependence on indices
counts = op.Reshape(counts, indices_size)
output_size = op.Shape(unique_values)
counts = op.Reshape(counts, op.Reshape(output_size[dim], [-1]))

Check warning on line 8507 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8505-L8507

Added lines #L8505 - L8507 were not covered by tests
else:
counts = op.ConstantOfShape([0], value=[0])

Check warning on line 8509 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8509

Added line #L8509 was not covered by tests
# force dependence on indices
counts = counts * indices_numel

Check warning on line 8511 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8511

Added line #L8511 was not covered by tests
# force dependence on inverse_indices
counts = counts * inverse_indices_numel
return unique_values, inverse_indices, counts

Check warning on line 8514 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L8513-L8514

Added lines #L8513 - L8514 were not covered by tests


def aten_unique_dim_consecutive(
Expand Down
Loading