Skip to content

Commit

Permalink
fix dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Oct 23, 2024
1 parent 0da7a69 commit 0c9c48e
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,42 @@ def aten_arange_start_step(
"""arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

if dtype == -1:
result = op.Range(start, end, step)
# TODO: Because this is a trace_only function, the inputs are not promoted to
# Tensor until it hits ONNX ops. However, if it's dynamic, it should be
# Tensor at this point.
# https://github.com/microsoft/onnxscript/issues/1914
if isinstance(start, (int, float)):
start_is_int = isinstance(start, int)
else:
start_is_int = start.dtype in {

Check warning on line 642 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L642

Added line #L642 was not covered by tests
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if isinstance(end, (int, float)):
end_is_int = isinstance(end, int)
else:
end_is_int = end.dtype in {

Check warning on line 650 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L650

Added line #L650 was not covered by tests
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if isinstance(step, (int, float)):
step_is_int = isinstance(step, int)
else:
step_is_int = step.dtype in {

Check warning on line 658 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L658

Added line #L658 was not covered by tests
INT16.dtype,
INT32.dtype,
INT64.dtype,
}
if start_is_int and end_is_int and step_is_int:
result = op.Range(start, end, step)
else:
# to float
start = op.Cast(start, to=FLOAT.dtype)
end = op.Cast(end, to=FLOAT.dtype)
step = op.Cast(step, to=FLOAT.dtype)

Check warning on line 669 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L667-L669

Added lines #L667 - L669 were not covered by tests
result = op.Range(start, end, step)
elif _integral_to_be_adjusted(dtype):
# PyTorch arange op handles these integral types differently from INT64,
# so we have to adjust these arguments accordingly.
Expand Down

0 comments on commit 0c9c48e

Please sign in to comment.