From efa64e2d38e86c9c89153fcd1c479e063e189911 Mon Sep 17 00:00:00 2001
From: Ehsan Nadjaran Toosi
Date: Thu, 19 Sep 2024 16:24:26 +0100
Subject: [PATCH 1/4] feat: legalization from xten_nn.reduce_mean to aten.mean.dim

---
 lib/Conversion/XTenNNToTorch.cpp              |  16 +
 .../Conversion/XTenNNToTorch/reduce_mean.mlir | 288 ++++++++++++++++++
 2 files changed, 304 insertions(+)
 create mode 100644 test/Conversion/XTenNNToTorch/reduce_mean.mlir

diff --git a/lib/Conversion/XTenNNToTorch.cpp b/lib/Conversion/XTenNNToTorch.cpp
index ed96f980..fa6b79c6 100644
--- a/lib/Conversion/XTenNNToTorch.cpp
+++ b/lib/Conversion/XTenNNToTorch.cpp
@@ -220,6 +220,21 @@ convTranspose2dToTorch(ConvTransposeOp op, ConvTransposeOp::Adaptor adaptor,
       ->getResults();
 }
 
+std::optional<ValueRange>
+reduceMeanToTorch(ReduceMeanOp op, ReduceMeanOp::Adaptor adaptor,
+                  ArrayRef<Type> types, ValueRange values,
+                  ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  auto noneConst = rewriter.create<Torch::ConstantNoneOp>(loc);
+  auto keepdims =
+      rewriter.create<Torch::ConstantBoolOp>(loc, adaptor.getKeepdims());
+  auto axes = Torch::toTorchList(loc, rewriter, adaptor.getAxes().vec());
+  return rewriter
+      .create<Torch::AtenMeanDimOp>(loc, types[0], values[0], axes, keepdims,
+                                    noneConst)
+      ->getResults();
+}
+
 std::optional<ValueRange>
 resizeToTorch(ResizeOp op, ResizeOp::Adaptor adaptor, ArrayRef<Type> types,
               ValueRange values, ConversionPatternRewriter &rewriter) {
@@ -439,6 +454,7 @@ struct ConvertXTenNNToTorch
     patterns.add>(context);
     patterns.add>(
         context);
+    patterns.add>(context);
     if (failed(applyPartialConversion(funcOp, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/test/Conversion/XTenNNToTorch/reduce_mean.mlir b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
new file mode 100644
index 00000000..2a0714ea
--- /dev/null
+++ b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
@@ -0,0 +1,288 @@
+// RUN: aten-opt --convert-xtennn-to-torch -split-input-file %s | FileCheck %s
+// REQUIRES: torch
+
+func.func @reduce_mean_one_axis_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32>
+  return %0 : tensor<4x512x1x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_one_axis_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,1,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,1,8],f32> -> tensor<4x512x1x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x1x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_three_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32>
+  return %0 : tensor<4x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_three_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1,1],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4,1,1,1],f32> -> tensor<4x1x1x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<4x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_all_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_all_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_one_axis(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32>
+  return %0 : tensor<4x512x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_one_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,8],f32> -> tensor<4x512x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_three_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_three_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+// CHECK: return %[[VAL_9]] : tensor<4xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_all_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<f32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_all_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: return %[[VAL_10]] : tensor<f32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_mean_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32>
+  return %0 : tensor<4x512x256x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean_noop_with_empty_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[4,512,256,8],f32> -> tensor<4x512x256x8xf32>
+// CHECK: return %[[VAL_6]] : tensor<4x512x256x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_one_axis_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32>
+  return %0 : tensor<4x512x1x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_one_axis_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x1x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,1,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,1,8],f32> -> tensor<4x512x1x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x1x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_three_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32>
+  return %0 : tensor<4x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_three_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1,1],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4,1,1,1],f32> -> tensor<4x1x1x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<4x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_all_axes_keep_dims(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_all_axes_keep_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_one_axis(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 2>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32>
+  return %0 : tensor<4x512x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_one_axis(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,8],f32>
+// CHECK: %[[VAL_7:.*]] = torch_c.to_builtin_tensor %[[VAL_6]] : !torch.vtensor<[4,512,8],f32> -> tensor<4x512x8xf32>
+// CHECK: return %[[VAL_7]] : tensor<4x512x8xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_three_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_three_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_7]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4],f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.to_builtin_tensor %[[VAL_8]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+// CHECK: return %[[VAL_9]] : tensor<4xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_all_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<f32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 0 : i64} : (tensor<4x512x256x8xf32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_all_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: return %[[VAL_10]] : tensor<f32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_meanv13_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32>
+  return %0 : tensor<4x512x256x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_meanv13_noop_with_empty_axes(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,256,8],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[4,512,256,8],f32> -> tensor<4x512x256x8xf32>
+// CHECK: return %[[VAL_10]] : tensor<4x512x256x8xf32>
+// CHECK: }
\ No newline at end of file
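
A note for readers following the conversion in this patch: the lowering delegates list construction to a Torch::toTorchList utility that already exists in this codebase, so only its call site appears in the diff. A minimal sketch of what such a helper plausibly looks like follows; the name and call shape are taken from the call site above, while the body, parameter types, and builder usage are assumptions, not code from this repository:

    // Sketch only: materialize one torch.constant.int per element, then wrap
    // them in a !torch.list<int> with torch.prim.ListConstruct.
    static Value toTorchList(Location loc, OpBuilder &b,
                             SmallVector<int64_t> values) {
      SmallVector<Value> elements;
      for (int64_t v : values)
        elements.push_back(
            b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(v)));
      return b.create<Torch::PrimListConstructOp>(
          loc, Torch::ListType::get(Torch::IntType::get(b.getContext())),
          elements);
    }

This matches the IR the tests expect: one torch.constant.int per axis followed by a single torch.prim.ListConstruct producing a !torch.list<int>.
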
XTenNN_Op<"reduce_mean", [ + Pure, TosaExtension, + InferTensorTypeAdaptor]> { let summary = "Reduce Mean operation"; let description = [{ This operation is equivalent to `onnx.ReduceMean` and computes the mean of the input tensor's elements along the provided axes. }]; + let arguments = (ins AnyRankedTensor:$input, DenseI64ArrayAttr:$axes, I64Attr:$keepdims ); + let results = (outs AnyRankedTensor:$output ); + let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; } diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 381a8342..c5ef6bc0 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -26,6 +27,7 @@ #include "xten/Dialect/XTenNN/IR/XTenNNBase.h" #include "xten/Dialect/XTenNN/IR/XTenNNOps.h" #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h" +#include using namespace mlir; using namespace amd::xten_nn; @@ -264,7 +266,9 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) { return parseEnclaveOp(p, result); } -void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); } +void SubgraphOp::print(OpAsmPrinter &p) { + printEnclaveOp(p, *this); +} LogicalResult SubgraphOp::verify() { Block *optBody = this->getOptionalEnclaveBody(); @@ -593,3 +597,39 @@ bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) { getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]); return sameElementType && succeeded(verifyCompatibleShapes(l, r)); } + +LogicalResult ReduceMeanOp::inferReturnTypeComponents( + MLIRContext * /*context*/, std::optional location, + ReduceMeanOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + + auto inTy = cast(adaptor.getInput().getType()); + auto inDims = inTy.getShape(); + auto keepDims = adaptor.getKeepdims(); + + auto axes = adaptor.getAxes(); + llvm::SmallVector newAxes(inDims); + for (auto axis : axes) { + // onnx spec: axis: [-r, r-1] + if (axis < -inTy.getRank() || axis >= inTy.getRank()) { + return emitOptionalError(location, + "expected axis to be within [-rank,rank) (where " + "rank is the rank of the input)"); + } + // normalize axis: [0, r) + if (axis < 0) { + axis += inTy.getRank(); + } + assert((axis >= 0 && axis < inTy.getRank()) && "axis has invalid value"); + + if (keepDims) { + newAxes[axis] = 1; + } else { + newAxes.erase(newAxes.begin() + axis); + } + } + + inferredReturnShapes.push_back( + ShapedTypeComponents(newAxes, inTy.getElementType())); + return success(); +} \ No newline at end of file From 4411faecab48cd84fc2c1fff7182d309e3be22bd Mon Sep 17 00:00:00 2001 From: Ehsan Nadjaran Toosi Date: Fri, 20 Sep 2024 12:23:39 +0100 Subject: [PATCH 3/4] fix: infershape --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 39 +++++++++++++++---- .../Conversion/XTenNNToTorch/reduce_mean.mlir | 16 ++++---- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index c5ef6bc0..72e7fe9e 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -604,11 +604,11 @@ LogicalResult ReduceMeanOp::inferReturnTypeComponents( SmallVectorImpl &inferredReturnShapes) { auto inTy = cast(adaptor.getInput().getType()); - auto inDims = inTy.getShape(); auto 
+  auto inDims = inTy.getShape();
+  auto keepDims = adaptor.getKeepdims();
+
+  auto axes = adaptor.getAxes();
+  llvm::SmallVector<int64_t> newAxes(inDims);
+  for (auto axis : axes) {
+    // onnx spec: axis: [-r, r-1]
+    if (axis < -inTy.getRank() || axis >= inTy.getRank()) {
+      return emitOptionalError(location,
+                               "expected axis to be within [-rank,rank) (where "
+                               "rank is the rank of the input)");
+    }
+    // normalize axis: [0, r)
+    if (axis < 0) {
+      axis += inTy.getRank();
+    }
+    assert((axis >= 0 && axis < inTy.getRank()) && "axis has invalid value");
+
+    if (keepDims) {
+      newAxes[axis] = 1;
+    } else {
+      newAxes.erase(newAxes.begin() + axis);
+    }
+  }
+
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(newAxes, inTy.getElementType()));
+  return success();
+}
\ No newline at end of file
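
Why a follow-up fix is needed: the loop above reuses newAxes (initialized as a copy of the input shape) both as the list of reduced axes and as the output shape, and erase() shifts every later dimension. A worked example with the shapes used in the tests (illustrative only, not part of the patch):

    // input: tensor<4x512x256x8xf32>, axes = [1, 2], keepdims = 0
    // expected result shape: 4x8
    // the loop above: erase(1) -> [4, 256, 8]; erase(2) -> [4, 256]  (wrong)
    // with axes = [1, 2, 3] the situation is worse: after two erases only
    // three, then two entries remain, so erase at index 3 runs past the end.

The next patch in the series reworks this into two passes: first validate and normalize the axes, then build the output shape in a single sweep over the input dimensions.
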
From 4411faecab48cd84fc2c1fff7182d309e3be22bd Mon Sep 17 00:00:00 2001
From: Ehsan Nadjaran Toosi
Date: Fri, 20 Sep 2024 12:23:39 +0100
Subject: [PATCH 3/4] fix: infershape

---
 lib/Dialect/XTenNN/IR/XTenNNOps.cpp           | 39 +++++++++++++++----
 .../Conversion/XTenNNToTorch/reduce_mean.mlir | 16 ++++----
 2 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
index c5ef6bc0..72e7fe9e 100644
--- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
+++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -604,11 +604,11 @@ LogicalResult ReduceMeanOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 
   auto inTy = cast<ShapedType>(adaptor.getInput().getType());
-  auto inDims = inTy.getShape();
   auto keepDims = adaptor.getKeepdims();
-
   auto axes = adaptor.getAxes();
-  llvm::SmallVector<int64_t> newAxes(inDims);
+
+  // Sanitize axes
+  llvm::SmallVector<int64_t> newAxes;
   for (auto axis : axes) {
     // onnx spec: axis: [-r, r-1]
     if (axis < -inTy.getRank() || axis >= inTy.getRank()) {
@@ -616,20 +616,45 @@ LogicalResult ReduceMeanOp::inferReturnTypeComponents(
                                "expected axis to be within [-rank,rank) (where "
                                "rank is the rank of the input)");
     }
+
     // normalize axis: [0, r)
     if (axis < 0) {
       axis += inTy.getRank();
     }
+
     assert((axis >= 0 && axis < inTy.getRank()) && "axis has invalid value");
+    newAxes.push_back(axis);
+  }
 
-    if (keepDims) {
-      newAxes[axis] = 1;
+  SmallVector<int64_t> outputShape;
+  auto inputShape = inTy.getShape();
+  for (auto [idx, dim] : llvm::enumerate(inputShape)) {
+    if (llvm::is_contained(newAxes, idx)) {
+      if (keepDims) {
+        outputShape.push_back(1);
+      }
     } else {
-      newAxes.erase(newAxes.begin() + axis);
+      outputShape.push_back(dim);
     }
   }
 
+  llvm::errs() << keepDims << "############\n";
+  for (auto elem : newAxes) {
+    llvm::errs() << elem << ",";
+  }
+  llvm::errs() << "\n";
+
+  for (auto elem : inputShape) {
+    llvm::errs() << elem << ",";
+  }
+  llvm::errs() << "\n";
+
+  for (auto elem : outputShape) {
+    llvm::errs() << elem << ",";
+  }
+  llvm::errs() << "\n";
+
   inferredReturnShapes.push_back(
-      ShapedTypeComponents(newAxes, inTy.getElementType()));
+      ShapedTypeComponents(outputShape, inTy.getElementType()));
   return success();
 }
\ No newline at end of file
diff --git a/test/Conversion/XTenNNToTorch/reduce_mean.mlir b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
index 2a0714ea..45d4a1e7 100644
--- a/test/Conversion/XTenNNToTorch/reduce_mean.mlir
+++ b/test/Conversion/XTenNNToTorch/reduce_mean.mlir
@@ -267,13 +267,13 @@ func.func @reduce_meanv13_all_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<f32> {
 
 // -----
 
-func.func @reduce_meanv13_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
-  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32>
-  return %0 : tensor<4x512x256x8xf32>
+func.func @reduce_meanv13_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> {
+  %0 = xten_nn.reduce_mean %arg0 {axes = array<i64: 0, 1, 2, 3>, keepdims = 1 : i64} : (tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32>
+  return %0 : tensor<1x1x1x1xf32>
 }
 
 // CHECK-LABEL: func.func @reduce_meanv13_noop_with_empty_axes(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<1x1x1x1xf32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
 // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<4x512x256x8xf32> -> !torch.vtensor<[4,512,256,8],f32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.none
 // CHECK: %[[VAL_3:.*]] = torch.constant.bool true
@@ -282,7 +282,7 @@ func.func @reduce_meanv13_noop_with_empty_axes(%arg0: tensor<4x512x256x8xf32>) -
 // CHECK: %[[VAL_6:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_7:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,512,256,8],f32>
-// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[4,512,256,8],f32> -> tensor<4x512x256x8xf32>
-// CHECK: return %[[VAL_10]] : tensor<4x512x256x8xf32>
-// CHECK: }
\ No newline at end of file
+// CHECK: %[[VAL_9:.*]] = torch.aten.mean.dim %[[VAL_1]], %[[VAL_8]], %[[VAL_3]], %[[VAL_2]] : !torch.vtensor<[4,512,256,8],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32>
+// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[1,1,1,1],f32> -> tensor<1x1x1x1xf32>
+// CHECK: return %[[VAL_10]] : tensor<1x1x1x1xf32>
+// CHECK: }

From 6f7009dd8f9d4c5ae32600b37b9dc6a6d6883b71 Mon Sep 17 00:00:00 2001
From: Pablo Lanza Serrano
Date: Fri, 20 Sep 2024 13:33:21 +0100
Subject: [PATCH 4/4] Removed debug messages

---
 lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
index 72e7fe9e..cc4c67af 100644
--- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
+++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -638,22 +638,6 @@ LogicalResult ReduceMeanOp::inferReturnTypeComponents(
     }
   }
 
-  llvm::errs() << keepDims << "############\n";
-  for (auto elem : newAxes) {
-    llvm::errs() << elem << ",";
-  }
-  llvm::errs() << "\n";
-
-  for (auto elem : inputShape) {
-    llvm::errs() << elem << ",";
-  }
-  llvm::errs() << "\n";
-
-  for (auto elem : outputShape) {
-    llvm::errs() << elem << ",";
-  }
-  llvm::errs() << "\n";
-
   inferredReturnShapes.push_back(
       ShapedTypeComponents(outputShape, inTy.getElementType()));
   return success();
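
For reference, a consolidated view of the shape-inference hook as it stands at the end of this series, assembled from the hunks above into one listing. This is a sketch: template arguments such as ShapedType in the cast and int64_t in the SmallVectors are reconstructions consistent with the surrounding code rather than text taken verbatim from the patches.

    LogicalResult ReduceMeanOp::inferReturnTypeComponents(
        MLIRContext * /*context*/, std::optional<Location> location,
        ReduceMeanOp::Adaptor adaptor,
        SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
      auto inTy = cast<ShapedType>(adaptor.getInput().getType());
      auto keepDims = adaptor.getKeepdims();
      auto axes = adaptor.getAxes();

      // Validate each requested axis against the ONNX spec ([-r, r-1]) and
      // normalize it into [0, rank).
      llvm::SmallVector<int64_t> newAxes;
      for (auto axis : axes) {
        if (axis < -inTy.getRank() || axis >= inTy.getRank())
          return emitOptionalError(location,
                                   "expected axis to be within [-rank,rank) "
                                   "(where rank is the rank of the input)");
        if (axis < 0)
          axis += inTy.getRank();
        newAxes.push_back(axis);
      }

      // Reduced dimensions either collapse to 1 (keepdims) or are dropped.
      SmallVector<int64_t> outputShape;
      for (auto [idx, dim] : llvm::enumerate(inTy.getShape())) {
        if (llvm::is_contained(newAxes, static_cast<int64_t>(idx))) {
          if (keepDims)
            outputShape.push_back(1);
        } else {
          outputShape.push_back(dim);
        }
      }

      inferredReturnShapes.push_back(
          ShapedTypeComponents(outputShape, inTy.getElementType()));
      return success();
    }

With InferTensorTypeAdaptor on the op, this hook also lets the verifier reject a declared result type that disagrees with the inferred one, which is why the reduce_meanv13_noop_with_empty_axes test had to be retyped in patch 3.
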