Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Add the identity nodes back #1703

Merged
merged 8 commits into from
Jun 25, 2024
Merged

Conversation

xadupre
Copy link
Member

@xadupre xadupre commented Jun 24, 2024

In the modularization pass in the exporter, a single node like clone can be lifted as a function. If we remove the only Identity node the lifted function will have no nodes. This violates the ONNX standard.

Since removing identity nodes is fast, we are safe to include these identity nodes in the torchlib.

onnxscript/tools/transformers_models/phi_test.py broke after #1613, it is fixed by this change.

Signed-off-by: Xavier Dupre <[email protected]>
Copy link

codecov bot commented Jun 24, 2024

Codecov Report

Attention: Patch coverage is 92.30769% with 2 lines in your changes missing coverage. Please review.

Project coverage is 76.28%. Comparing base (1aa7a70) to head (0a18177).

Files Patch % Lines
onnxscript/function_libs/torch_lib/ops/core.py 91.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1703      +/-   ##
==========================================
+ Coverage   75.75%   76.28%   +0.53%     
==========================================
  Files         240      240              
  Lines       25495    25496       +1     
  Branches     4549     4550       +1     
==========================================
+ Hits        19313    19449     +136     
+ Misses       5281     5159     -122     
+ Partials      901      888      -13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Signed-off-by: Xavier Dupre <[email protected]>
Signed-off-by: Xavier Dupre <[email protected]>
@xadupre xadupre marked this pull request as ready for review June 24, 2024 14:28
@xadupre xadupre changed the title [WIP] investigations Revert one change introduced by #1613 and breaking onnxscript/tools/transformers_models/phi_test.py Jun 24, 2024
Signed-off-by: Xavier Dupre <[email protected]>
@justinchuby
Copy link
Collaborator

Could you explain why the change is needed? Thanks!

@xadupre
Copy link
Member Author

xadupre commented Jun 24, 2024

Could you explain why the change is needed? Thanks!

It is only a guess. One onnx function was not added to the model. I suspect when one function has no node, it is not added to the final model. I was hoping you could confirm.

@justinchuby
Copy link
Collaborator

When trying to reproduce, I see the error happens in the optimizer:

justinchu@justinchu-dev-linux2 ~/d/onnxscript (main)> pytest onnxscript/tools/transformers_models/phi_tes
t.py
========================================== test session starts ==========================================
platform linux -- Python 3.11.9, pytest-8.2.1, pluggy-1.5.0
Using --randomly-seed=296786577
rootdir: /home/justinchu/dev/onnxscript
configfile: pyproject.toml
plugins: hypothesis-6.103.0, subtests-0.12.1, randomly-3.15.0, cov-5.0.0, xdist-3.6.1
collected 3 items                                                                                       

onnxscript/tools/transformers_models/phi_test.py Fs.                                              [100%]

=============================================== FAILURES ================================================
___________________________________ TestExportPhi.test_phi_export_cpu ___________________________________
onnxscript/tools/transformers_models/phi_test.py:28: in test_phi_export_cpu
    proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors)
onnxscript/tools/transformers_models/__init__.py:28: in export_to_onnx
    model_proto = onnxscript.optimizer.optimize(
onnxscript/optimizer/__init__.py:80: in optimize
    inline_functions_with_unused_outputs(model)
onnxscript/optimizer/simple_function_folding.py:231: in inline_functions_with_unused_outputs
    inliner.visit_model(model)
onnxscript/optimizer/simple_function_folding.py:33: in visit_model
    super().visit_model(model)
onnxscript/_legacy_ir/visitor.py:792: in visit_model
    self.visit_graph(model.graph)
onnxscript/_legacy_ir/visitor.py:658: in visit_graph
    replacement = self.visit_node(node)
onnxscript/_legacy_ir/visitor.py:805: in visit_node
    replacement, _ = self.process_function_node(node)
onnxscript/optimizer/simple_function_folding.py:44: in process_function_node
    replacement, new_function = super().process_function_node(node)
onnxscript/_legacy_ir/visitor.py:892: in process_function_node
    replacement = self.visit_node(inner_node)
onnxscript/_legacy_ir/visitor.py:805: in visit_node
    replacement, _ = self.process_function_node(node)
onnxscript/optimizer/simple_function_folding.py:44: in process_function_node
    replacement, new_function = super().process_function_node(node)
onnxscript/_legacy_ir/visitor.py:846: in process_function_node
    actual_input_value_infos = [self.lookup(input) for input in node.input]
onnxscript/_legacy_ir/visitor.py:846: in <listcomp>
    actual_input_value_infos = [self.lookup(input) for input in node.input]
onnxscript/_legacy_ir/visitor.py:447: in lookup
    raise ValueError(
E   ValueError: Undefined variable model_embed_dropout_1.
E   Available variables: SubScope 0:
E     Function transformers_models_phi_modeling_phi_PhiModel_model_1:
E       ir.Values:
E         l_input_ids_: StaticValueInfo(l_input_ids_, shape:[13, 7], dtype:7, no const value.)
E         l_attention_mask_: StaticValueInfo(l_attention_mask_, shape:[13, 7], dtype:1, no const value.)
E         model.embed_tokens.weight: StaticValueInfo(model.embed_tokens.weight, shape:[99, 32], dtype:1, has const value.)
E         model.layers.0.input_layernorm.weight: StaticValueInfo(model.layers.0.input_layernorm.weight, shape:[32], dtype:1, has const value.)
E         model.layers.0.input_layernorm.bias: StaticValueInfo(model.layers.0.input_layernorm.bias, shape:[32], dtype:1, has const value.)
E         model.layers.0.self_attn.q_proj.weight: StaticValueInfo(model.layers.0.self_attn.q_proj.weight, shape:[32, 32], dtype:1, has const value.)
E         model.layers.0.self_attn.q_proj.bias: StaticValueInfo(model.layers.0.self_attn.q_proj.bias, shape:[32], dtype:1, has const value.)
E         model.layers.0.self_attn.k_proj.weight: StaticValueInfo(model.layers.0.self_attn.k_proj.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.0.self_attn.k_proj.bias: StaticValueInfo(model.layers.0.self_attn.k_proj.bias, shape:[16], dtype:1, has const value.)
E         model.layers.0.self_attn.v_proj.weight: StaticValueInfo(model.layers.0.self_attn.v_proj.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.0.self_attn.v_proj.bias: StaticValueInfo(model.layers.0.self_attn.v_proj.bias, shape:[16], dtype:1, has const value.)
E         model.layers.0.self_attn.rotary_emb.cos_cached: StaticValueInfo(model.layers.0.self_attn.rotary_emb.cos_cached, shape:[512, 4], dtype:1, has const value.)
E         model.layers.0.self_attn.rotary_emb.sin_cached: StaticValueInfo(model.layers.0.self_attn.rotary_emb.sin_cached, shape:[512, 4], dtype:1, has const value.)
E         model.layers.0.self_attn.dense.weight: StaticValueInfo(model.layers.0.self_attn.dense.weight, shape:[32, 32], dtype:1, has const value.)
E         model.layers.0.self_attn.dense.bias: StaticValueInfo(model.layers.0.self_attn.dense.bias, shape:[32], dtype:1, has const value.)
E         model.layers.0.mlp.fc1.weight: StaticValueInfo(model.layers.0.mlp.fc1.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.0.mlp.fc1.bias: StaticValueInfo(model.layers.0.mlp.fc1.bias, shape:[16], dtype:1, has const value.)
E         model.layers.0.mlp.fc2.weight: StaticValueInfo(model.layers.0.mlp.fc2.weight, shape:[32, 16], dtype:1, has const value.)
E         model.layers.0.mlp.fc2.bias: StaticValueInfo(model.layers.0.mlp.fc2.bias, shape:[32], dtype:1, has const value.)
E         model.layers.1.input_layernorm.weight: StaticValueInfo(model.layers.1.input_layernorm.weight, shape:[32], dtype:1, has const value.)
E         model.layers.1.input_layernorm.bias: StaticValueInfo(model.layers.1.input_layernorm.bias, shape:[32], dtype:1, has const value.)
E         model.layers.1.self_attn.q_proj.weight: StaticValueInfo(model.layers.1.self_attn.q_proj.weight, shape:[32, 32], dtype:1, has const value.)
E         model.layers.1.self_attn.q_proj.bias: StaticValueInfo(model.layers.1.self_attn.q_proj.bias, shape:[32], dtype:1, has const value.)
E         model.layers.1.self_attn.k_proj.weight: StaticValueInfo(model.layers.1.self_attn.k_proj.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.1.self_attn.k_proj.bias: StaticValueInfo(model.layers.1.self_attn.k_proj.bias, shape:[16], dtype:1, has const value.)
E         model.layers.1.self_attn.v_proj.weight: StaticValueInfo(model.layers.1.self_attn.v_proj.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.1.self_attn.v_proj.bias: StaticValueInfo(model.layers.1.self_attn.v_proj.bias, shape:[16], dtype:1, has const value.)
E         model.layers.1.self_attn.rotary_emb.cos_cached: StaticValueInfo(model.layers.1.self_attn.rotary_emb.cos_cached, shape:[512, 4], dtype:1, has const value.)
E         model.layers.1.self_attn.rotary_emb.sin_cached: StaticValueInfo(model.layers.1.self_attn.rotary_emb.sin_cached, shape:[512, 4], dtype:1, has const value.)
E         model.layers.1.self_attn.dense.weight: StaticValueInfo(model.layers.1.self_attn.dense.weight, shape:[32, 32], dtype:1, has const value.)
E         model.layers.1.self_attn.dense.bias: StaticValueInfo(model.layers.1.self_attn.dense.bias, shape:[32], dtype:1, has const value.)
E         model.layers.1.mlp.fc1.weight: StaticValueInfo(model.layers.1.mlp.fc1.weight, shape:[16, 32], dtype:1, has const value.)
E         model.layers.1.mlp.fc1.bias: StaticValueInfo(model.layers.1.mlp.fc1.bias, shape:[16], dtype:1, has const value.)
E         model.layers.1.mlp.fc2.weight: StaticValueInfo(model.layers.1.mlp.fc2.weight, shape:[32, 16], dtype:1, has const value.)
E         model.layers.1.mlp.fc2.bias: StaticValueInfo(model.layers.1.mlp.fc2.bias, shape:[32], dtype:1, has const value.)
E         model.final_layernorm.weight: StaticValueInfo(model.final_layernorm.weight, shape:[32], dtype:1, has const value.)
E         model.final_layernorm.bias: StaticValueInfo(model.final_layernorm.bias, shape:[32], dtype:1, has const value.)
E         unsqueeze: StaticValueInfo(unsqueeze, shape:[1, 7], dtype:7, no const value.)
E         _val_39: StaticValueInfo(_val_39, shape:[1], dtype:7, no const value.)
E         _val_43: StaticValueInfo(_val_43, shape:[1], dtype:7, no const value.)
E         _val_47: StaticValueInfo(_val_47, shape:[1], dtype:7, no const value.)
E         _val_51: StaticValueInfo(_val_51, shape:[1], dtype:7, no const value.)
E         slice_3: StaticValueInfo(slice_3, shape:[13, 7], dtype:1, no const value.)
E         aten_unsqueeze_84_dim_0: StaticValueInfo(aten_unsqueeze_84_dim_0, shape:[], dtype:7, no const value.)
E         unsqueeze_3: StaticValueInfo(unsqueeze_3, shape:[13, 1, 7], dtype:1, no const value.)
E         aten_unsqueeze_85_dim_0: StaticValueInfo(aten_unsqueeze_85_dim_0, shape:[], dtype:7, no const value.)
E         unsqueeze_4: StaticValueInfo(unsqueeze_4, shape:[13, 1, 1, 7], dtype:1, no const value.)
E         _val_58: StaticValueInfo(_val_58, shape:[1], dtype:7, no const value.)
E         _val_62: StaticValueInfo(_val_62, shape:[1], dtype:7, no const value.)
E         _val_66: StaticValueInfo(_val_66, shape:[1], dtype:7, no const value.)
E         _val_70: StaticValueInfo(_val_70, shape:[1], dtype:7, no const value.)
E         slice_4: StaticValueInfo(slice_4, shape:[13, 1, 1, 7], dtype:1, no const value.)
E         _val_72: StaticValueInfo(_val_72, shape:[4], dtype:7, no const value.)
E         aten_expand_104_size_1: StaticValueInfo(aten_expand_104_size_1, shape:[4], dtype:7, no const value.)
E         expand_1: StaticValueInfo(expand_1, shape:[13, 1, 7, 7], dtype:1, no const value.)
E         _val_74: StaticValueInfo(_val_74, shape:[], dtype:1, no const value.)
E         rsub: StaticValueInfo(rsub, shape:[13, 1, 7, 7], dtype:1, no const value.)
E         _to_copy: StaticValueInfo(_to_copy, shape:[13, 1, 7, 7], dtype:9, no const value.)
E         _val_77: StaticValueInfo(_val_77, shape:[], dtype:1, no const value.)
E         masked_fill_1: StaticValueInfo(masked_fill_1, shape:[13, 1, 7, 7], dtype:1, no const value.)
E         _to_copy_1: StaticValueInfo(_to_copy_1, shape:[13, 1, 7, 7], dtype:9, no const value.)
E         expand_2: StaticValueInfo(expand_2, shape:[13, 1, 7, 7], dtype:1, no const value.)
E         _val_119: StaticValueInfo(_val_119, shape:[], dtype:1, no const value.)
E         masked_fill_2: StaticValueInfo(masked_fill_2, shape:[13, 1, 7, 7], dtype:1, no const value.)
E         model_layers_0_1: StaticValueInfo(model_layers_0_1, shape:[13, 2, 7, 8], dtype:1, no const value.)
E         model_layers_0_1_1: StaticValueInfo(model_layers_0_1_1, shape:[13, 2, 7, 8], dtype:1, no const value.)
E         model_layers_0_1_2: StaticValueInfo(model_layers_0_1_2, shape:[13, 7, 32], dtype:1, no const value.)
E         model_layers_1_1: StaticValueInfo(model_layers_1_1, shape:[13, 2, 7, 8], dtype:1, no const value.)
E         model_layers_1_1_1: StaticValueInfo(model_layers_1_1_1, shape:[13, 2, 7, 8], dtype:1, no const value.)
E         model_layers_1_1_2: StaticValueInfo(model_layers_1_1_2, shape:[13, 7, 32], dtype:1, no const value.)
E         model_final_layernorm_1: StaticValueInfo(model_final_layernorm_1, shape:[13, 7, 32], dtype:1, no const value.)
E       RefAttributes:

@xadupre
Copy link
Member Author

xadupre commented Jun 24, 2024

It is but if you try to load the model with onnxruntime, it fails. The error happens before the optimizer because the model is not valid. It is not detected by check_model or infer_shapes. If you look at the onnx model, you'll see a missing function.

@justinchuby
Copy link
Collaborator

Thanks for the info, I will look deeper. Do you have the name of the missing function and the error stack from onnx runtime?

@justinchuby
Copy link
Collaborator

E   onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: ONNX Schema torch_nn_modules_dropout_Dropout_model_embed_dropout_1: failed validating the check: !(it.GetName().empty())

@xadupre
Copy link
Member Author

xadupre commented Jun 24, 2024

Something including dropout in its name. But I don't this matters. I just took the file you modified and replaced return self by return op.Identity(self). We can leave it that way or replace only the instances where the function may be empty. But I think we can have more identity node than expected. It is easy to remove anyway and fast.

@justinchuby
Copy link
Collaborator

Sounds good. The changes look good to me. Thanks!

Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. please revert the ci changes thanks!

onnxscript/rewriter/llama_rule_sets_test.py Outdated Show resolved Hide resolved
@justinchuby justinchuby changed the title Revert one change introduced by #1613 and breaking onnxscript/tools/transformers_models/phi_test.py [torchlib] Add the identity nodes back Jun 24, 2024
@justinchuby justinchuby added the topic: torch_lib Related to the torch/aten function lib in development label Jun 24, 2024
@justinchuby justinchuby merged commit be00339 into microsoft:main Jun 25, 2024
37 of 46 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
topic: torch_lib Related to the torch/aten function lib in development
Projects
Development

Successfully merging this pull request may close these issues.

2 participants