-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix implementation of index_put, validate it with dort | fix(torchlib) #1277
Fix implementation of index_put, validate it with dort | fix(torchlib) #1277
Conversation
Signed-off-by: Xavier Dupre <[email protected]>
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1277 +/- ##
==========================================
+ Coverage 78.68% 78.89% +0.20%
==========================================
Files 119 119
Lines 15762 15809 +47
Branches 2486 2498 +12
==========================================
+ Hits 12403 12473 +70
+ Misses 2950 2919 -31
- Partials 409 417 +8 ☔ View full report in Codecov by Sentry. |
Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
# onnxruntime: MLFloat16 data type is not supported with ScatterND when reduction is 'add' | ||
and (sample.args[1].dtype != torch.float16 or not sample.kwargs.get("accumulate", False)) | ||
), | ||
reason="this Aten overload only support tensor(int) as indices or float16 when reduction is 'add'", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested this matcher with the original code and it shows no different in results (still 4 passes and 20 skipped). It seems that you found a corner case that is not in the current op_DB_test
? If that's the case, we usually use https://github.com/microsoft/onnxscript/blob/main/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py to create corner case to test the implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did. I hope I did it right.
Please also update the PR description so that torchlib team: @justinchuby @xiaowuhu @fatcat-z can have more context later. Thanks! |
Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: xadupre <[email protected]>
Little changes |
), | ||
TorchLibOpInfo( | ||
"index_put", | ||
core_ops.aten_index_put, | ||
).skip( | ||
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64), | ||
reason="this Aten overload only support tensor(int) as args", | ||
enabled_if=version_utils.onnxruntime_older_than("1.17"), |
Check warning
Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning test
), | ||
TorchLibOpInfo( | ||
"index_put", | ||
core_ops.aten_index_put, | ||
).skip( | ||
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64), | ||
reason="this Aten overload only support tensor(int) as args", | ||
enabled_if=version_utils.onnxruntime_older_than("1.17"), |
Check warning
Code scanning / lintrunner
RUFF/W291 Warning test
See https://docs.astral.sh/ruff/rules/trailing-whitespace
1c1b9a5
to
0de0d30
Compare
Merging into a microsoft branch so that it is easier to work on. |
286c70e
into
microsoft:justinchu/new-release
Continuation of #1277 by @xadupre The onnx implementation of index_put is different in torch script exporter (https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212). The PR replaces the current implementation failing on one corner case by the one from torch script. --------- Signed-off-by: Xavier Dupre <[email protected]> Signed-off-by: xadupre <[email protected]> Co-authored-by: Xavier Dupré <[email protected]>
The onnx implementation of index_put is different in torch script exporter (https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212). The PR replaces the current implementation failing on one corner case by the one from torch script.