From d3eaef11fc7a7f8078ea0d3e25adf8cc766c63be Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 8 May 2023 11:28:51 +0800 Subject: [PATCH] add teset --- python/tests/op_mappers/test_matmul_op.py | 53 +++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 python/tests/op_mappers/test_matmul_op.py diff --git a/python/tests/op_mappers/test_matmul_op.py b/python/tests/op_mappers/test_matmul_op.py new file mode 100644 index 0000000000..fe64e3014e --- /dev/null +++ b/python/tests/op_mappers/test_matmul_op.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_mapper_test import OpMapperTest, logger +import paddle + + +class TestMatmulOp(OpMapperTest): + def init_input_data(self): + self.feed_data = { + "x": self.random([16, 32], "float32"), + "y": self.random([32, 16], "float32") + } + + def set_op_type(self): + return "matmul" + + def set_op_inputs(self): + x = paddle.static.data('X', self.feed_data["x"].shape, + self.feed_data["x"].dtype) + x = paddle.static.data('Y', self.feed_data["y"].shape, + self.feed_data["Y"].dtype) + return {'X': [x], 'Y': [y]} + + def set_op_attrs(self): + return { + "trans_x": False, + "trans_y": False + } + + def set_op_outputs(self): + return {'Out': [str(self.feed_data['x'].dtype)]} + + def test_check_results(self): + self.check_outputs_and_grads() + +if __name__ == "__main__": + unittest.main()