Fix implementation of index_put, validate it with dort | fix(torchlib) #1277

Merged: 14 commits, Mar 13, 2024
onnxscript/function_libs/torch_lib/ops/core.py: 11 additions, 16 deletions

```diff
@@ -4044,27 +4044,22 @@ def aten_index_put(
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
-    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
+    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+    See implementation of `torch.onnx.symbolic_opset11.index_put
+    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
+    """
 
-    index = op.SequenceAt(indices, 0)  # assume indices only have 1 element
-    # change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
-    self_dim_1 = op.Gather(op.Shape(self), 1)
-    index_dim_0 = op.Gather(op.Shape(index), 0)
-    neg_1 = op.Constant(value_ints=[-1])
-    shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
-    new_ind = op.Expand(index, shape)
-    new_ind_t = op.Transpose(new_ind)
+    index = op.SequenceAt(indices, 0)
+    new_index = op.Unsqueeze(index, [-1])
+    shape_self = op.Shape(self)
 
     if op.Cast(accumulate, to=BOOL.dtype):
-        # put values into zeros array first, then add to input
-        zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
-        zeros = op.CastLike(zeros, values)
-        result = op.ScatterElements(zeros, new_ind_t, values)
-        # FIXME: type promotion
-        result = op.CastLike(result, self)
+        zeros = op.CastLike(op.ConstantOfShape(shape_self), values)
+        result = op.ScatterND(zeros, new_index, values, reduction="add")
+        result = op.Add(result, self)
     else:
-        result = op.ScatterElements(self, new_ind_t, values)
+        result = op.ScatterND(self, new_index, values)
 
     return result
```
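The rewrite collapses the old Expand/Transpose/ScatterElements construction, which assumed a 2-D `self` indexed along dim 0, into a single Unsqueeze followed by ScatterND. Scattering into a zeros tensor and then adding `self` is equivalent to scattering with `reduction="add"` directly into `self`; the zeros detour lets CastLike align the dtype with `values` before the final Add. A minimal numpy sketch of what the new graph computes, assuming one index tensor addressing dim 0 (the `scatter_nd` and `index_put_sketch` helpers are illustrative stand-ins, not onnxscript APIs):

```python
import numpy as np

def scatter_nd(data, indices, updates, reduction=None):
    # Hand-rolled stand-in for ONNX ScatterND, restricted to rank-1
    # coordinates: the last axis of `indices` holds one index into dim 0.
    out = data.copy()
    idx = indices[..., 0]
    if reduction == "add":
        np.add.at(out, idx, updates)  # unbuffered +=, handles duplicate indices
    else:
        out[idx] = updates
    return out

def index_put_sketch(self_, index, values, accumulate=False):
    new_index = np.expand_dims(index, -1)        # op.Unsqueeze(index, [-1])
    if accumulate:
        zeros = np.zeros_like(self_)             # ConstantOfShape + CastLike
        result = scatter_nd(zeros, new_index, values, reduction="add")
        return result + self_                    # op.Add(result, self)
    return scatter_nd(self_, new_index, values)  # op.ScatterND

data = np.zeros((10, 3), dtype=np.float32)
index = np.arange(8).reshape(2, 4)
values = np.ones((2, 4, 3), dtype=np.float32)
print(index_put_sketch(data, index, values, accumulate=True).sum())  # 24.0
```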
27 changes: 27 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,26 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
yield opinfo_core.SampleInput(make_arg((s, s, s, s)), args=args)


def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

data = torch_testing.make_tensor(
(10, 3),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
indices = torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))
values = torch_testing.make_tensor(
(2, 3, 4),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
yield opinfo_core.SampleInput(data, indices, values)


def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
del op_info # unused
del kwargs
Expand Down Expand Up @@ -1933,6 +1953,13 @@ def __init__(self):
),
sample_inputs_func=sample_inputs_index,
),
opinfo_core.OpInfo(
"ops.aten.index_put",
aten_name="index_put",
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_index_put,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.layer_norm",
aten_name="layer_norm",
Expand Down
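Two details of the sample were fixed above: `indices` must be a list (aten `index_put` takes `Tensor?[]`, and `aten_index_put` reads it with `op.SequenceAt`), and indexing the (10, 3) `data` through a (2, 4) index tensor selects a (2, 4, 3) block, so `values` must be broadcastable to (2, 4, 3), not (2, 3, 4). A standalone eager-mode sanity check of that shape contract (not part of the test suite):

```python
import torch

# Mirrors the sample above: data[idx] has shape (2, 4, 3),
# so values of shape (2, 4, 3) fit exactly.
data = torch.zeros((10, 3))
indices = [torch.arange(8, dtype=torch.int64).reshape((-1, 4))]
values = torch.ones((2, 4, 3))

out = torch.ops.aten.index_put(data, indices, values, accumulate=True)
print(out.shape)  # torch.Size([10, 3])
print(out.sum())  # tensor(24.) -- one added to each of the 8 indexed rows x 3 cols
```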
onnxscript/tests/function_libs/torch_lib/ops_test_data.py: 12 additions, 3 deletions

```diff
@@ -836,15 +836,24 @@ def _where_input_wrangler(
         core_ops.aten_index_put_bool,
     ).skip(
         matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
-        reason="this Aten overload only support tensor(bool) as args",
+        reason="this Aten overload only supports tensor(bool) as indices",
     ),
     TorchLibOpInfo(
         "index_put",
         core_ops.aten_index_put,
     ).skip(
-        matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
-        reason="this Aten overload only support tensor(int) as args",
+        enabled_if=version_utils.onnxruntime_older_than("1.16"),
+        matcher=lambda sample: not (
+            (sample.args[0][0].dtype == torch.int64)
+            # onnxruntime: MLFloat16 data type is not supported with ScatterND when reduction is 'add'
+            and (
+                sample.args[1].dtype != torch.float16
+                or not sample.kwargs.get("accumulate", False)
+            )
+        ),
+        reason="this Aten overload only supports tensor(int) as indices and float32 when accumulate is True",
     ),
+    TorchLibOpInfo("ops.aten.index_put", core_ops.aten_index_put),
     TorchLibOpInfo("index_select", core_ops.aten_index_select),
     TorchLibOpInfo("isclose", core_ops.aten_isclose),
     TorchLibOpInfo("isfinite", core_ops.aten_isfinite),
```
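The float16 skip above exists because the quoted onnxruntime error rejects ScatterND with `reduction="add"` on MLFloat16 data. A hedged repro sketch; the opset version, provider choice, and where the error surfaces (session creation vs. run) are assumptions, not taken from this PR:

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

# Build a one-node model: ScatterND with reduction="add" on float16 tensors.
node = helper.make_node(
    "ScatterND", ["data", "indices", "updates"], ["out"], reduction="add"
)
graph = helper.make_graph(
    [node],
    "scatternd_fp16_add",
    [
        helper.make_tensor_value_info("data", TensorProto.FLOAT16, [10, 3]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 4, 1]),
        helper.make_tensor_value_info("updates", TensorProto.FLOAT16, [2, 4, 3]),
    ],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT16, [10, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

feeds = {
    "data": np.zeros((10, 3), dtype=np.float16),
    "indices": np.arange(8, dtype=np.int64).reshape(2, 4, 1),
    "updates": np.ones((2, 4, 3), dtype=np.float16),
}
try:
    sess = ort.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    sess.run(None, feeds)
    print("ScatterND(add) on float16 is supported by this onnxruntime")
except Exception as e:
    print("unsupported:", e)  # affected versions report the MLFloat16 error
```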