Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix implementation of index_put, validate it with dort | fix(torchlib) #1277

Merged
merged 14 commits into from
Mar 13, 2024

Conversation

xadupre
Copy link
Member

@xadupre xadupre commented Feb 14, 2024

The onnx implementation of index_put is different in torch script exporter (https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212). The PR replaces the current implementation failing on one corner case by the one from torch script.

Copy link

codecov bot commented Feb 14, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.89%. Comparing base (ce3eb4a) to head (07fc1b5).
Report is 17 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1277      +/-   ##
==========================================
+ Coverage   78.68%   78.89%   +0.20%     
==========================================
  Files         119      119              
  Lines       15762    15809      +47     
  Branches     2486     2498      +12     
==========================================
+ Hits        12403    12473      +70     
+ Misses       2950     2919      -31     
- Partials      409      417       +8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
@xadupre xadupre changed the title validate the translation of index_put with dort fix implementation of index_put, validate it with dort Feb 14, 2024
Signed-off-by: Xavier Dupre <[email protected]>
@titaiwangms titaiwangms self-requested a review February 14, 2024 17:05
@titaiwangms titaiwangms added the topic: torch_lib Related to the torch/aten function lib in development label Feb 14, 2024
# onnxruntime: MLFloat16 data type is not supported with ScatterND when reduction is 'add'
and (sample.args[1].dtype != torch.float16 or not sample.kwargs.get("accumulate", False))
),
reason="this Aten overload only support tensor(int) as indices or float16 when reduction is 'add'",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this matcher with the original code and it shows no different in results (still 4 passes and 20 skipped). It seems that you found a corner case that is not in the current op_DB_test? If that's the case, we usually use https://github.com/microsoft/onnxscript/blob/main/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py to create corner case to test the implementation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did. I hope I did it right.

@titaiwangms
Copy link
Contributor

Please also update the PR description so that torchlib team: @justinchuby @xiaowuhu @fatcat-z can have more context later. Thanks!

@justinchuby justinchuby changed the title fix implementation of index_put, validate it with dort Fix implementation of index_put, validate it with dort | fix(torchlib) Feb 15, 2024
Signed-off-by: Xavier Dupre <[email protected]>
@titaiwangms
Copy link
Contributor

titaiwangms commented Feb 20, 2024

Little changes got merge not merged in main, could be useful for us. Might be unrelated/related. pytorch/pytorch#110860

),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
).skip(
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
reason="this Aten overload only support tensor(int) as args",
enabled_if=version_utils.onnxruntime_older_than("1.17"),

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning test

Trailing whitespace
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
).skip(
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
reason="this Aten overload only support tensor(int) as args",
enabled_if=version_utils.onnxruntime_older_than("1.17"),

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

@justinchuby justinchuby changed the base branch from main to justinchu/new-release March 13, 2024 02:00
@justinchuby
Copy link
Collaborator

Merging into a microsoft branch so that it is easier to work on.

@justinchuby justinchuby merged commit 286c70e into microsoft:justinchu/new-release Mar 13, 2024
11 of 36 checks passed
justinchuby added a commit that referenced this pull request Mar 21, 2024
Continuation of #1277 by @xadupre 

The onnx implementation of index_put is different in torch script
exporter

(https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212).
The PR replaces the current implementation failing on one corner case by
the one from torch script.

---------

Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: xadupre <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
topic: torch_lib Related to the torch/aten function lib in development
Projects
Development

Successfully merging this pull request may close these issues.

4 participants