Skip to content

Commit

Permalink
Add sign to idempotent ops list
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyamukh6 committed Aug 18, 2023
1 parent 0e49375 commit 8086dfe
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct EliminateConsecutiveIdempotentOps final : public PredicateBasedPass {

bool patternMatchPredicate(Node* node) override {
static const std::unordered_set<std::string> idempotent_ops = {
"Ceil", "Floor", "Round", "Relu", "Reshape"};
"Ceil", "Floor", "Round", "Relu", "Reshape", "Sign"};
for (const auto& op : idempotent_ops) {
// TODO: support uses().size() > 1 for ops except Reshape
if (CheckKind(node, Symbol(op), 0, Symbol(op)) &&
Expand Down
19 changes: 19 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4597,6 +4597,25 @@ def test_eliminate_consecutive_idempotent_op(self):
assert optimized_model.graph.node[0].op_type == "Constant"
assert optimized_model.graph.node[1].op_type == "Reshape"

def test_eliminate_consecutive_idempotent_sign_op(self):
model = parser.parse_model("""
<
ir_version: 7,
opset_import:["": 11]
>
agraph (float[1, 2, 3] X) => (float[1, 2, 3] Z)
{
T1 = Sign(X)
T2 = Sign(T1)
T3 = Sign(T2)
Z = Sign(T3)
}
""")

optimized_model = self._optimized(
model, ['eliminate_consecutive_idempotent_ops', 'eliminate_deadend'], True)
assert len(optimized_model.graph.node) == 1
assert optimized_model.graph.node[0].op_type == "Sign"

if __name__ == "__main__":
unittest.main()

0 comments on commit 8086dfe

Please sign in to comment.