diff --git a/onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h b/onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h index a6fafc1ef..02d7c7b76 100644 --- a/onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h +++ b/onnxoptimizer/passes/eliminate_consecutive_idempotent_ops.h @@ -24,7 +24,7 @@ struct EliminateConsecutiveIdempotentOps final : public PredicateBasedPass { bool patternMatchPredicate(Node* node) override { static const std::unordered_set idempotent_ops = { - "Ceil", "Floor", "Round", "Relu", "Reshape"}; + "Ceil", "Floor", "Round", "Relu", "Reshape", "Sign"}; for (const auto& op : idempotent_ops) { // TODO: support uses().size() > 1 for ops except Reshape if (CheckKind(node, Symbol(op), 0, Symbol(op)) && diff --git a/onnxoptimizer/test/optimizer_test.py b/onnxoptimizer/test/optimizer_test.py index 5cd6b32fd..941d5eb43 100644 --- a/onnxoptimizer/test/optimizer_test.py +++ b/onnxoptimizer/test/optimizer_test.py @@ -4597,6 +4597,25 @@ def test_eliminate_consecutive_idempotent_op(self): assert optimized_model.graph.node[0].op_type == "Constant" assert optimized_model.graph.node[1].op_type == "Reshape" + def test_eliminate_consecutive_idempotent_sign_op(self): + model = parser.parse_model(""" + < + ir_version: 7, + opset_import:["": 11] + > + agraph (float[1, 2, 3] X) => (float[1, 2, 3] Z) + { + T1 = Sign(X) + T2 = Sign(T1) + T3 = Sign(T2) + Z = Sign(T3) + } + """) + + optimized_model = self._optimized( + model, ['eliminate_consecutive_idempotent_ops', 'eliminate_deadend'], True) + assert len(optimized_model.graph.node) == 1 + assert optimized_model.graph.node[0].op_type == "Sign" if __name__ == "__main__": unittest.main()