diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index 04a7f4f69..d65f01c8d 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -281,6 +281,41 @@ def apply_pattern(op, x, **_): self.assertEqual(len(graph.node), 2) self.assertEqual(graph.node[0].op_type, "SinCos") + @unittest.skip("Input variable reuse not supported yet") + def test_shared_root_value_extra_use(self): + def match_pattern(op, x): + t1 = op.Sin(x) + t2 = op.Cos(x) + return t1, t2 + + def apply_pattern(op, x, **_): + return op.SinCos(x, domain="com.microsoft", outputs=2) + + rule = pattern.RewriteRule( + match_pattern, + apply_pattern, + matcher=generic_pattern.GenericPatternMatcher, + ) + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] y) => (float[N] z) + { + temp1 = Sin(y) + temp2 = Cos(y) + w = Add(temp1, temp2) + z = Mul(w, y) + } + """ + ) + onnx.checker.check_model(model_proto) + model = onnx.shape_inference.infer_shapes(model_proto) + ir_model = ir.serde.deserialize_model(model) + rule.apply_to_model(ir_model) + graph = ir_model.graph + self.assertEqual(len(graph), 3) + self.assertEqual(graph.node[0].op_type, "SinCos") + def test_rotary_embedding(self): # The test work on a model if it has the expected name. # A dummy model is used if not present (not implemented yet).